From 0320bb7b35c31b4c875af30ccde017f68e26afb9 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 26 Aug 2025 22:32:12 +0200 Subject: [PATCH 01/41] [management] Report sync duration and login duration by accountID (#4406) --- management/server/grpcserver.go | 14 ++++++++------ management/server/telemetry/grpc_metrics.go | 11 +++++++---- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index a637cf02d..41d463ae3 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -198,7 +198,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) + s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) } unlock() @@ -436,11 +436,7 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa // In case of the successful registration login is also successful func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { reqStart := time.Now() - defer func() { - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart)) - } - }() + if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountLoginRequest() } @@ -463,6 +459,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p //nolint ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + defer func() { + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) + } + }() + if loginReq.GetMeta() == nil { msg := status.Errorf(codes.FailedPrecondition, "peer system meta has to be provided to log in. 
Peer %s, remote addr %s", peerKey.String(), realIP) diff --git a/management/server/telemetry/grpc_metrics.go b/management/server/telemetry/grpc_metrics.go index ac6ff2ea8..16ad5c61c 100644 --- a/management/server/telemetry/grpc_metrics.go +++ b/management/server/telemetry/grpc_metrics.go @@ -4,9 +4,12 @@ import ( "context" "time" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) +const AccountIDLabel = "account_id" + // GRPCMetrics are gRPC server metrics type GRPCMetrics struct { meter metric.Meter @@ -111,13 +114,13 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequest() { } // CountLoginRequestDuration counts the duration of the login gRPC requests -func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration) { - grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) +func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) { + grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds(), metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) } // CountSyncRequestDuration counts the duration of the sync gRPC requests -func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration) { - grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) +func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) { + grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds(), metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) } // RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge. 
From 7ce5507c05ef617c264f0dd2fb6aa57635b522f1 Mon Sep 17 00:00:00 2001 From: "Krzysztof Nazarewski (kdn)" Date: Wed, 27 Aug 2025 09:59:39 +0200 Subject: [PATCH 02/41] [client] fix darwin dns always throwing err (#4403) * fix: dns/host_darwin.go was missing if err != nil before throwing error --- client/internal/dns/host_darwin.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 852dfef48..b06ba73ab 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -166,9 +166,10 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) addLocalDNS() error { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { - err := s.recordSystemDNSSettings(true) - log.Errorf("Unable to get system DNS configuration") - return err + if err := s.recordSystemDNSSettings(true); err != nil { + log.Errorf("Unable to get system DNS configuration") + return fmt.Errorf("recordSystemDNSSettings(): %w", err) + } } localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { From 99bd34c02a1bcd45af5d5a18f5a1a3a92bafda90 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Wed, 27 Aug 2025 18:30:49 +0200 Subject: [PATCH 03/41] [signal] fix goroutines and memory leak on forward messages between peers (#3896) --- go.mod | 36 ++--- go.sum | 79 ++++----- management/server/types/network.go | 4 +- signal/cmd/run.go | 1 + signal/peer/peer.go | 48 +++--- signal/peer/peer_test.go | 251 +++++++++++++++++++++++++++-- signal/server/signal.go | 109 +++++++++---- 7 files changed, 402 insertions(+), 126 deletions(-) diff --git a/go.mod b/go.mod index 6beed2ff5..33aea7618 100644 --- a/go.mod +++ b/go.mod @@ -18,12 +18,12 @@ require ( github.com/spf13/cobra v1.7.0 
github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.3.0 - golang.org/x/crypto v0.37.0 - golang.org/x/sys v0.32.0 + golang.org/x/crypto v0.40.0 + golang.org/x/sys v0.34.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/grpc v1.64.1 + google.golang.org/grpc v1.73.0 google.golang.org/protobuf v1.36.6 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) @@ -93,18 +93,18 @@ require ( github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.1.3 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 - go.opentelemetry.io/otel v1.26.0 + go.opentelemetry.io/otel v1.35.0 go.opentelemetry.io/otel/exporters/prometheus v0.48.0 - go.opentelemetry.io/otel/metric v1.26.0 - go.opentelemetry.io/otel/sdk/metric v1.26.0 + go.opentelemetry.io/otel/metric v1.35.0 + go.opentelemetry.io/otel/sdk/metric v1.35.0 go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/net v0.39.0 - golang.org/x/oauth2 v0.27.0 - golang.org/x/sync v0.13.0 - golang.org/x/term v0.31.0 + golang.org/x/net v0.42.0 + golang.org/x/oauth2 v0.28.0 + golang.org/x/sync v0.16.0 + golang.org/x/term v0.33.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 @@ -117,7 +117,7 @@ require ( require ( cloud.google.com/go/auth v0.3.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect - cloud.google.com/go/compute/metadata v0.3.0 // indirect + cloud.google.com/go/compute/metadata v0.6.0 // indirect dario.cat/mergo v1.0.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect @@ -232,19 +232,19 @@ require ( github.com/yuin/goldmark v1.7.1 // indirect github.com/zeebo/blake3 v0.2.3 // indirect 
go.opencensus.io v0.24.0 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect - go.opentelemetry.io/otel/sdk v1.26.0 // indirect - go.opentelemetry.io/otel/trace v1.26.0 // indirect + go.opentelemetry.io/otel/sdk v1.35.0 // indirect + go.opentelemetry.io/otel/trace v1.35.0 // indirect go.uber.org/mock v0.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect - golang.org/x/mod v0.17.0 // indirect - golang.org/x/text v0.24.0 // indirect + golang.org/x/mod v0.25.0 // indirect + golang.org/x/text v0.27.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect + golang.org/x/tools v0.34.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) diff --git a/go.sum b/go.sum index 5a8236332..562548edd 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,8 @@ cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUM cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= -cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= -cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +cloud.google.com/go/compute/metadata v0.6.0 
h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= +cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqClKRT5SZwBmk= @@ -588,8 +588,8 @@ github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0 github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= -github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so= github.com/rs/cors v1.8.0/go.mod h1:EBwu+T5AvHOcXwvZIkQFjUN6s8Czyqw12GL/Y0tUyRM= github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4= @@ -712,26 +712,28 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 
h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0/go.mod h1:Mjt1i1INqiaoZOMGR1RIUJN+i3ChKoFRqzrRQhlkbs0= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc= -go.opentelemetry.io/otel v1.26.0 h1:LQwgL5s/1W7YiiRwxf03QGnWLb2HW4pLiAhaA5cZXBs= -go.opentelemetry.io/otel v1.26.0/go.mod h1:UmLkJHUAidDval2EICqBMbnAd0/m2vmpf/dAM+fvFs4= +go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= +go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU= go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s= go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o= -go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgSllRz30= -go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4= -go.opentelemetry.io/otel/sdk v1.26.0 h1:Y7bumHf5tAiDlRYFmGqetNcLaVUZmh4iYfmGxtmz7F8= -go.opentelemetry.io/otel/sdk v1.26.0/go.mod h1:0p8MXpqLeJ0pzcszQQN4F0S5FVjBLgypeGSngLsmirs= -go.opentelemetry.io/otel/sdk/metric v1.26.0 h1:cWSks5tfriHPdWFnl+qpX3P681aAYqlZHcAyHw5aU9Y= -go.opentelemetry.io/otel/sdk/metric v1.26.0/go.mod 
h1:ClMFFknnThJCksebJwz7KIyEDHO+nTB6gK8obLy8RyE= -go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA= -go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0= +go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M= +go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE= +go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY= +go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg= +go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o= +go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w= +go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs= +go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= @@ -759,8 +761,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= -golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= +golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod 
h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -806,8 +808,8 @@ golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= +golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -853,8 +855,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= -golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= +golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod 
h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -868,8 +870,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= -golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M= -golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= +golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc= +golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -883,8 +885,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= -golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod 
h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -952,8 +954,8 @@ golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= -golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -961,8 +963,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= -golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= +golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg= +golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0= golang.org/x/text 
v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -976,8 +978,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= -golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= +golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1040,8 +1042,8 @@ golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.8-0.20211022200916-316ba0b74098/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= +golang.org/x/tools v0.34.0/go.mod 
h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1124,10 +1126,11 @@ google.golang.org/genproto v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A= google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= -google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 h1:OpXbo8JnN8+jZGPrL4SSfaDjSCjupr8lXyBAbexEm/U= -google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ= +google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463 h1:hE3bRWtU6uceqlh4fhrSnUyjKHMKB9KrTLLG+bc0ddM= +google.golang.org/genproto/googleapis/api v0.0.0-20250324211829-b45e905df463/go.mod h1:U90ffi8eUL9MwPcrJylN5+Mk2v3vuPDptd5yyNUiRR8= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 h1:pFyd6EwwL2TqFf8emdthzeX+gZE1ElRq3iM8pui4KBY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= 
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -1148,8 +1151,8 @@ google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= -google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= -google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= +google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok= +google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/management/server/types/network.go b/management/server/types/network.go index f072a4294..ffc019565 100644 --- a/management/server/types/network.go +++ b/management/server/types/network.go @@ -12,11 +12,11 @@ import ( "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/shared/management/proto" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/management/status" ) const ( diff 
--git a/signal/cmd/run.go b/signal/cmd/run.go index 2e89b491a..1d76fa4e4 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/http" + // nolint:gosec _ "net/http/pprof" "strings" diff --git a/signal/peer/peer.go b/signal/peer/peer.go index f21c95a41..c9dd60fc0 100644 --- a/signal/peer/peer.go +++ b/signal/peer/peer.go @@ -5,10 +5,16 @@ import ( "sync" "time" + "errors" + log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/shared/signal/proto" + "github.com/netbirdio/netbird/signal/metrics" +) + +var ( + ErrPeerAlreadyRegistered = errors.New("peer already registered") ) // Peer representation of a connected Peer @@ -23,15 +29,18 @@ type Peer struct { // registration time RegisteredAt time.Time + + Cancel context.CancelFunc } // NewPeer creates a new instance of a connected Peer -func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer) *Peer { +func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) *Peer { return &Peer{ Id: id, Stream: stream, StreamID: time.Now().UnixNano(), RegisteredAt: time.Now(), + Cancel: cancel, } } @@ -69,20 +78,24 @@ func (registry *Registry) IsPeerRegistered(peerId string) bool { } // Register registers peer in the registry -func (registry *Registry) Register(peer *Peer) { +func (registry *Registry) Register(peer *Peer) error { start := time.Now() - registry.regMutex.Lock() - defer registry.regMutex.Unlock() - // can be that peer already exists, but it is fine (e.g. reconnect) p, loaded := registry.Peers.LoadOrStore(peer.Id, peer) if loaded { pp := p.(*Peer) - log.Tracef("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.", - peer.Id, peer.StreamID, pp.StreamID) - registry.Peers.Store(peer.Id, peer) - return + if peer.StreamID > pp.StreamID { + log.Tracef("peer [%s] is already registered [new streamID %d, previous StreamID %d]. 
Will override stream.", + peer.Id, peer.StreamID, pp.StreamID) + if swapped := registry.Peers.CompareAndSwap(peer.Id, pp, peer); !swapped { + return registry.Register(peer) + } + pp.Cancel() + log.Debugf("peer re-registered [%s]", peer.Id) + return nil + } + return ErrPeerAlreadyRegistered } log.Debugf("peer registered [%s]", peer.Id) @@ -92,22 +105,13 @@ func (registry *Registry) Register(peer *Peer) { registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6) registry.metrics.Registrations.Add(context.Background(), 1) + + return nil } // Deregister Peer from the Registry (usually once it disconnects) func (registry *Registry) Deregister(peer *Peer) { - registry.regMutex.Lock() - defer registry.regMutex.Unlock() - - p, loaded := registry.Peers.LoadAndDelete(peer.Id) - if loaded { - pp := p.(*Peer) - if peer.StreamID < pp.StreamID { - registry.Peers.Store(peer.Id, p) - log.Debugf("attempted to remove newer registered stream of a peer [%s] [newer streamID %d, previous StreamID %d]. 
Ignoring.", - peer.Id, pp.StreamID, peer.StreamID) - return - } + if deleted := registry.Peers.CompareAndDelete(peer.Id, peer); deleted { registry.metrics.ActivePeers.Add(context.Background(), -1) log.Debugf("peer deregistered [%s]", peer.Id) registry.metrics.Deregistrations.Add(context.Background(), 1) diff --git a/signal/peer/peer_test.go b/signal/peer/peer_test.go index fb85fedda..6b7976eb4 100644 --- a/signal/peer/peer_test.go +++ b/signal/peer/peer_test.go @@ -1,13 +1,18 @@ package peer import ( + "context" + "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/signal/metrics" ) @@ -19,12 +24,16 @@ func TestRegistry_ShouldNotDeregisterWhenHasNewerStreamRegistered(t *testing.T) peerID := "peer" - olderPeer := NewPeer(peerID, nil) - r.Register(olderPeer) + _, cancel1 := context.WithCancel(context.Background()) + olderPeer := NewPeer(peerID, nil, cancel1) + err = r.Register(olderPeer) + require.NoError(t, err) time.Sleep(time.Nanosecond) - newerPeer := NewPeer(peerID, nil) - r.Register(newerPeer) + _, cancel2 := context.WithCancel(context.Background()) + newerPeer := NewPeer(peerID, nil, cancel2) + err = r.Register(newerPeer) + require.NoError(t, err) registered, _ := r.Get(olderPeer.Id) assert.NotNil(t, registered, "peer can't be nil") @@ -59,10 +68,14 @@ func TestRegistry_Register(t *testing.T) { require.NoError(t, err) r := NewRegistry(metrics) - peer1 := NewPeer("test_peer_1", nil) - peer2 := NewPeer("test_peer_2", nil) - r.Register(peer1) - r.Register(peer2) + _, cancel1 := context.WithCancel(context.Background()) + peer1 := NewPeer("test_peer_1", nil, cancel1) + _, cancel2 := context.WithCancel(context.Background()) + peer2 := NewPeer("test_peer_2", nil, cancel2) + err = r.Register(peer1) + require.NoError(t, err) + err = r.Register(peer2) + 
require.NoError(t, err) if _, ok := r.Get("test_peer_1"); !ok { t.Errorf("expected test_peer_1 not found in the registry") @@ -78,10 +91,14 @@ func TestRegistry_Deregister(t *testing.T) { require.NoError(t, err) r := NewRegistry(metrics) - peer1 := NewPeer("test_peer_1", nil) - peer2 := NewPeer("test_peer_2", nil) - r.Register(peer1) - r.Register(peer2) + _, cancel1 := context.WithCancel(context.Background()) + peer1 := NewPeer("test_peer_1", nil, cancel1) + _, cancel2 := context.WithCancel(context.Background()) + peer2 := NewPeer("test_peer_2", nil, cancel2) + err = r.Register(peer1) + require.NoError(t, err) + err = r.Register(peer2) + require.NoError(t, err) r.Deregister(peer1) @@ -94,3 +111,213 @@ func TestRegistry_Deregister(t *testing.T) { } } + +func TestRegistry_MultipleRegister_Concurrency(t *testing.T) { + + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(t, err) + registry := NewRegistry(metrics) + + numGoroutines := 1000 + + ids := make(chan int64, numGoroutines) + + var wg sync.WaitGroup + wg.Add(numGoroutines) + peerID := "peer-concurrent" + for i := range numGoroutines { + go func(routineIndex int) { + defer wg.Done() + + _, cancel := context.WithCancel(context.Background()) + peer := NewPeer(peerID, nil, cancel) + _ = registry.Register(peer) + ids <- peer.StreamID + }(i) + } + + wg.Wait() + close(ids) + maxId := int64(0) + for id := range ids { + maxId = max(maxId, id) + } + + peer, ok := registry.Get(peerID) + require.True(t, ok, "expected peer to be registered") + require.Equal(t, maxId, peer.StreamID, "expected the highest StreamID to be registered") +} + +func Benchmark_MultipleRegister_Concurrency(b *testing.B) { + + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(b, err) + + numGoroutines := 1000 + + var wg sync.WaitGroup + peerID := "peer-concurrent" + _, cancel := context.WithCancel(context.Background()) + b.Run("multiple-register", func(b *testing.B) { + registry := NewRegistry(metrics) + 
b.ResetTimer() + for j := 0; j < b.N; j++ { + wg.Add(numGoroutines) + for i := range numGoroutines { + go func(routineIndex int) { + defer wg.Done() + + peer := NewPeer(peerID, nil, cancel) + _ = registry.Register(peer) + }(i) + } + wg.Wait() + } + }) +} + +func TestRegistry_MultipleDeregister_Concurrency(t *testing.T) { + + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(t, err) + registry := NewRegistry(metrics) + + numGoroutines := 1000 + + ids := make(chan int64, numGoroutines) + + var wg sync.WaitGroup + wg.Add(numGoroutines) + peerID := "peer-concurrent" + for i := range numGoroutines { + go func(routineIndex int) { + defer wg.Done() + + _, cancel := context.WithCancel(context.Background()) + peer := NewPeer(peerID, nil, cancel) + _ = registry.Register(peer) + ids <- peer.StreamID + registry.Deregister(peer) + }(i) + } + + wg.Wait() + close(ids) + maxId := int64(0) + for id := range ids { + maxId = max(maxId, id) + } + + _, ok := registry.Get(peerID) + require.False(t, ok, "expected peer to be deregistered") +} + +func Benchmark_MultipleDeregister_Concurrency(b *testing.B) { + + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(b, err) + + numGoroutines := 1000 + + var wg sync.WaitGroup + peerID := "peer-concurrent" + _, cancel := context.WithCancel(context.Background()) + b.Run("register-deregister", func(b *testing.B) { + registry := NewRegistry(metrics) + b.ResetTimer() + for j := 0; j < b.N; j++ { + wg.Add(numGoroutines) + for i := range numGoroutines { + go func(routineIndex int) { + defer wg.Done() + + peer := NewPeer(peerID, nil, cancel) + _ = registry.Register(peer) + time.Sleep(time.Nanosecond) + registry.Deregister(peer) + }(i) + } + wg.Wait() + } + }) +} + +type mockConnectStreamServer struct { + grpc.ServerStream + ctx context.Context +} + +func (m *mockConnectStreamServer) Context() context.Context { + return m.ctx +} + +func (m *mockConnectStreamServer) SendHeader(md metadata.MD) error { + return nil 
+} + +func (m *mockConnectStreamServer) Send(msg *proto.EncryptedMessage) error { + return nil +} + +func (m *mockConnectStreamServer) Recv() (*proto.EncryptedMessage, error) { + <-m.ctx.Done() + return nil, m.ctx.Err() +} + +func TestReconnectHandling(t *testing.T) { + metrics, err := metrics.NewAppMetrics(otel.Meter("")) + require.NoError(t, err) + registry := NewRegistry(metrics) + peerID := "test-peer-reconnect" + + ctx1, cancel1 := context.WithCancel(context.Background()) + defer cancel1() + stream1 := &mockConnectStreamServer{ctx: ctx1} + peer1 := NewPeer(peerID, stream1, cancel1) + + err = registry.Register(peer1) + require.NoError(t, err, "first registration should succeed") + + p, found := registry.Get(peerID) + require.True(t, found, "peer should be found in the registry") + require.Equal(t, peer1.StreamID, p.StreamID, "StreamID of registered peer should match") + + time.Sleep(time.Nanosecond) + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + stream2 := &mockConnectStreamServer{ctx: ctx2} + peer2 := NewPeer(peerID, stream2, cancel2) + + err = registry.Register(peer2) + require.NoError(t, err, "reconnect registration should succeed") + + select { + case <-ctx1.Done(): + require.ErrorIs(t, ctx1.Err(), context.Canceled, "context of old stream should be canceled after successful reconnection") + case <-time.After(100 * time.Millisecond): + t.Fatal("context of old stream was not canceled after reconnection") + } + + p, found = registry.Get(peerID) + require.True(t, found) + require.Equal(t, peer2.StreamID, p.StreamID, "registered peer should have the new StreamID after reconnection") + + ctx3, cancel3 := context.WithCancel(context.Background()) + defer cancel3() + stream3 := &mockConnectStreamServer{ctx: ctx3} + stalePeer := NewPeer(peerID, stream3, cancel3) + stalePeer.StreamID = peer1.StreamID + + err = registry.Register(stalePeer) + require.ErrorIs(t, err, ErrPeerAlreadyRegistered, "reconnecting with an old StreamID should 
return an error") + + p, found = registry.Get(peerID) + require.True(t, found) + require.Equal(t, peer2.StreamID, p.StreamID, "active peer should still be the one with the latest StreamID") + + select { + case <-ctx2.Done(): + t.Fatal("context of the new stream should not be canceled after trying to register with an old StreamID") + default: + } +} diff --git a/signal/server/signal.go b/signal/server/signal.go index 8ae14822b..47f01edae 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -2,7 +2,9 @@ package server import ( "context" + "errors" "fmt" + "os" "time" log "github.com/sirupsen/logrus" @@ -15,9 +17,9 @@ import ( "github.com/netbirdio/signal-dispatcher/dispatcher" + "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/peer" - "github.com/netbirdio/netbird/shared/signal/proto" ) const ( @@ -27,6 +29,8 @@ const ( labelTypeNotRegistered = "not_registered" labelTypeStream = "stream" labelTypeMessage = "message" + labelTypeTimeout = "timeout" + labelTypeDisconnected = "disconnected" labelError = "error" labelErrorMissingId = "missing_id" @@ -37,6 +41,12 @@ const ( labelRegistrationStatus = "status" labelRegistrationFound = "found" labelRegistrationNotFound = "not_found" + + sendTimeout = 10 * time.Second +) + +var ( + ErrPeerRegisteredAgain = errors.New("peer registered again") ) // Server an instance of a Signal server @@ -45,6 +55,10 @@ type Server struct { proto.UnimplementedSignalExchangeServer dispatcher *dispatcher.Dispatcher metrics *metrics.AppMetrics + + successHeader metadata.MD + + sendTimeout time.Duration } // NewServer creates a new Signal server @@ -59,10 +73,19 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) { return nil, fmt.Errorf("creating dispatcher: %v", err) } + sTimeout := sendTimeout + to := os.Getenv("NB_SIGNAL_SEND_TIMEOUT") + if parsed, err := time.ParseDuration(to); err == nil && parsed > 0 { + 
log.Trace("using custom send timeout ", parsed) + sTimeout = parsed + } + s := &Server{ - dispatcher: d, - registry: peer.NewRegistry(appMetrics), - metrics: appMetrics, + dispatcher: d, + registry: peer.NewRegistry(appMetrics), + metrics: appMetrics, + successHeader: metadata.Pairs(proto.HeaderRegistered, "1"), + sendTimeout: sTimeout, } return s, nil @@ -82,7 +105,8 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto. // ConnectStream connects to the exchange stream func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error { - p, err := s.RegisterPeer(stream) + ctx, cancel := context.WithCancel(context.Background()) + p, err := s.RegisterPeer(stream, cancel) if err != nil { return err } @@ -90,8 +114,7 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) defer s.DeregisterPeer(p) // needed to confirm that the peer has been registered so that the client can proceed - header := metadata.Pairs(proto.HeaderRegistered, "1") - err = stream.SendHeader(header) + err = stream.SendHeader(s.successHeader) if err != nil { s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedHeader))) return err @@ -99,27 +122,27 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID) - <-stream.Context().Done() - log.Debugf("peer stream closing [%s] [streamID %d] ", p.Id, p.StreamID) - return nil + select { + case <-stream.Context().Done(): + log.Debugf("peer stream closing [%s] [streamID %d] ", p.Id, p.StreamID) + return nil + case <-ctx.Done(): + return ErrPeerRegisteredAgain + } } -func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) { +func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) (*peer.Peer, error) { log.Debugf("registering new 
peer") - meta, hasMeta := metadata.FromIncomingContext(stream.Context()) - if !hasMeta { - s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingMeta))) - return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta") - } - - id, found := meta[proto.HeaderId] - if !found { + id := metadata.ValueFromIncomingContext(stream.Context(), proto.HeaderId) + if id == nil { s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId))) return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: %s", proto.HeaderId) } - p := peer.NewPeer(id[0], stream) - s.registry.Register(p) + p := peer.NewPeer(id[0], stream, cancel) + if err := s.registry.Register(p); err != nil { + return nil, err + } err := s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer) if err != nil { s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedRegistration))) @@ -131,8 +154,8 @@ func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) ( func (s *Server) DeregisterPeer(p *peer.Peer) { log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID) - s.registry.Deregister(p) s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds())) + s.registry.Deregister(p) } func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) { @@ -145,7 +168,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM if !found { s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound))) s.metrics.MessageForwardFailures.Add(ctx, 1, 
metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected))) - log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) + log.Tracef("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) // todo respond to the sender? return } @@ -153,16 +176,34 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound))) start := time.Now() - // forward the message to the target peer - if err := dstPeer.Stream.Send(msg); err != nil { - log.Tracef("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) - // todo respond to the sender? - s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) - return - } + sendResultChan := make(chan error, 1) + go func() { + select { + case sendResultChan <- dstPeer.Stream.Send(msg): + return + case <-dstPeer.Stream.Context().Done(): + return + } + }() - // in milliseconds - s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream))) - s.metrics.MessagesForwarded.Add(ctx, 1) - s.metrics.MessageSize.Record(ctx, int64(gproto.Size(msg)), metric.WithAttributes(attribute.String(labelType, labelTypeMessage))) + select { + case err := <-sendResultChan: + if err != nil { + log.Tracef("error while forwarding message from peer [%s] to peer [%s]: %v", msg.Key, msg.RemoteKey, err) + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) + return + } + s.metrics.MessageForwardLatency.Record(ctx, 
float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream))) + s.metrics.MessagesForwarded.Add(ctx, 1) + s.metrics.MessageSize.Record(ctx, int64(gproto.Size(msg)), metric.WithAttributes(attribute.String(labelType, labelTypeMessage))) + + case <-dstPeer.Stream.Context().Done(): + log.Tracef("failed to forward message from peer [%s] to peer [%s]: destination peer disconnected", msg.Key, msg.RemoteKey) + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeDisconnected))) + + case <-time.After(s.sendTimeout): + dstPeer.Cancel() // cancel the peer context to trigger deregistration + log.Tracef("failed to forward message from peer [%s] to peer [%s]: send timeout", msg.Key, msg.RemoteKey) + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeTimeout))) + } } From aa595c3073231cc0aa33e7d4d9581181129c8bb3 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 28 Aug 2025 13:25:16 +0200 Subject: [PATCH 04/41] [client] Fix shared sock buffer allocation (#4409) --- sharedsock/sock_linux.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index db428515b..d4fedc492 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -230,10 +230,8 @@ func (s *SharedSocket) Close() error { // read start a read loop for a specific receiver and sends the packet to the packetDemux channel func (s *SharedSocket) read(receiver receiver) { - // Buffer reuse is safe: packetDemux is unbuffered, so read() blocks until - // ReadFrom() synchronously processes the packet before next iteration - buf := make([]byte, s.mtu+maxIPUDPOverhead) for { + buf := make([]byte, s.mtu+maxIPUDPOverhead) n, addr, err := receiver(s.ctx, buf, 0) select { case <-s.ctx.Done(): From 4fd10b9447edb0eb5c2a505320e5bb479ac70423 Mon Sep 17 00:00:00 
2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 28 Aug 2025 13:25:40 +0200 Subject: [PATCH 05/41] [management] split high latency grpc metrics (#4408) --- management/server/telemetry/grpc_metrics.go | 67 +++++++++++++++------ 1 file changed, 47 insertions(+), 20 deletions(-) diff --git a/management/server/telemetry/grpc_metrics.go b/management/server/telemetry/grpc_metrics.go index 16ad5c61c..9840b4b32 100644 --- a/management/server/telemetry/grpc_metrics.go +++ b/management/server/telemetry/grpc_metrics.go @@ -9,18 +9,21 @@ import ( ) const AccountIDLabel = "account_id" +const HighLatencyThreshold = time.Second * 7 // GRPCMetrics are gRPC server metrics type GRPCMetrics struct { - meter metric.Meter - syncRequestsCounter metric.Int64Counter - loginRequestsCounter metric.Int64Counter - getKeyRequestsCounter metric.Int64Counter - activeStreamsGauge metric.Int64ObservableGauge - syncRequestDuration metric.Int64Histogram - loginRequestDuration metric.Int64Histogram - channelQueueLength metric.Int64Histogram - ctx context.Context + meter metric.Meter + syncRequestsCounter metric.Int64Counter + syncRequestHighLatencyCounter metric.Int64Counter + loginRequestsCounter metric.Int64Counter + loginRequestHighLatencyCounter metric.Int64Counter + getKeyRequestsCounter metric.Int64Counter + activeStreamsGauge metric.Int64ObservableGauge + syncRequestDuration metric.Int64Histogram + loginRequestDuration metric.Int64Histogram + channelQueueLength metric.Int64Histogram + ctx context.Context } // NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server @@ -33,6 +36,14 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } + syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold 
to establish a connection and receive network map updates (update channel)"), + ) + if err != nil { + return nil, err + } + loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter", metric.WithUnit("1"), metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"), @@ -41,6 +52,14 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } + loginRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.login.request.high.latency.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of login gRPC requests from the peers that took longer than the threshold to authenticate and receive initial configuration and relay credentials"), + ) + if err != nil { + return nil, err + } + getKeyRequestsCounter, err := meter.Int64Counter("management.grpc.key.request.counter", metric.WithUnit("1"), metric.WithDescription("Number of key gRPC requests from the peers to get the server's public WireGuard key"), @@ -86,15 +105,17 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro } return &GRPCMetrics{ - meter: meter, - syncRequestsCounter: syncRequestsCounter, - loginRequestsCounter: loginRequestsCounter, - getKeyRequestsCounter: getKeyRequestsCounter, - activeStreamsGauge: activeStreamsGauge, - syncRequestDuration: syncRequestDuration, - loginRequestDuration: loginRequestDuration, - channelQueueLength: channelQueue, - ctx: ctx, + meter: meter, + syncRequestsCounter: syncRequestsCounter, + syncRequestHighLatencyCounter: syncRequestHighLatencyCounter, + loginRequestsCounter: loginRequestsCounter, + loginRequestHighLatencyCounter: loginRequestHighLatencyCounter, + getKeyRequestsCounter: getKeyRequestsCounter, + activeStreamsGauge: activeStreamsGauge, + syncRequestDuration: syncRequestDuration, + loginRequestDuration: loginRequestDuration, + channelQueueLength: channelQueue, + 
ctx: ctx, }, err } @@ -115,12 +136,18 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequest() { // CountLoginRequestDuration counts the duration of the login gRPC requests func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) { - grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds(), metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) + grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) + if duration > HighLatencyThreshold { + grpcMetrics.loginRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) + } } // CountSyncRequestDuration counts the duration of the sync gRPC requests func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) { - grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds(), metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) + grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) + if duration > HighLatencyThreshold { + grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID))) + } } // RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge. 
From dbefa8bd9fd365124875fd05a8bb471b42bda1d8 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 28 Aug 2025 17:50:12 +0200 Subject: [PATCH 06/41] [management] remove lock and continue user update on failure (#4410) --- management/server/group.go | 89 ++++++++++++++++++++++---------------- management/server/user.go | 43 +++++++++++------- 2 files changed, 79 insertions(+), 53 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index 86bc0d8a0..487cb6d97 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -202,35 +202,45 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } var eventsToStore []func() - var groupsToSave []*types.Group var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - groupIDs := make([]string, 0, len(groups)) - for _, newGroup := range groups { + var globalErr error + groupIDs := make([]string, 0, len(groups)) + for _, newGroup := range groups { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { return err } newGroup.AccountID = accountID - groupsToSave = append(groupsToSave, newGroup) + + if err = transaction.CreateGroup(ctx, newGroup); err != nil { + return err + } + + err = transaction.IncrementNetworkSerial(ctx, accountID) + if err != nil { + return err + } + groupIDs = append(groupIDs, newGroup.ID) events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) eventsToStore = append(eventsToStore, events...) 
- } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) + return nil + }) if err != nil { - return err + log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err) + if len(groupIDs) == 1 { + return err + } + globalErr = errors.Join(globalErr, err) + // continue updating other groups } + } - if err = transaction.CreateGroups(ctx, accountID, groupsToSave); err != nil { - return err - } - - return transaction.IncrementNetworkSerial(ctx, accountID) - }) + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) if err != nil { return err } @@ -243,7 +253,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us am.UpdateAccountPeers(ctx, accountID) } - return nil + return globalErr } // UpdateGroups updates groups in the account. @@ -260,35 +270,45 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } var eventsToStore []func() - var groupsToSave []*types.Group var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - groupIDs := make([]string, 0, len(groups)) - for _, newGroup := range groups { + var globalErr error + groupIDs := make([]string, 0, len(groups)) + for _, newGroup := range groups { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { return err } newGroup.AccountID = accountID - groupsToSave = append(groupsToSave, newGroup) - groupIDs = append(groupIDs, newGroup.ID) + + if err = transaction.UpdateGroup(ctx, newGroup); err != nil { + return err + } + + err = transaction.IncrementNetworkSerial(ctx, accountID) + if err != nil { + return err + } events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) eventsToStore = append(eventsToStore, events...) 
- } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) + groupIDs = append(groupIDs, newGroup.ID) + + return nil + }) if err != nil { - return err + log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err) + if len(groups) == 1 { + return err + } + globalErr = errors.Join(globalErr, err) + // continue updating other groups } + } - if err = transaction.UpdateGroups(ctx, accountID, groupsToSave); err != nil { - return err - } - - return transaction.IncrementNetworkSerial(ctx, accountID) - }) + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) if err != nil { return err } @@ -301,7 +321,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us am.UpdateAccountPeers(ctx, accountID) } - return nil + return globalErr } // prepareGroupEvents prepares a list of event functions to be stored. @@ -584,13 +604,6 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st newGroup.ID = xid.New().String() } - for _, peerID := range newGroup.Peers { - _, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) - if err != nil { - return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) - } - } - return nil } diff --git a/management/server/user.go b/management/server/user.go index 4596ee95b..e5a4dbcea 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -519,33 +519,46 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUser = result } - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - for _, update := range updates { - if update == nil { - return status.Errorf(status.InvalidArgument, "provided user update is nil") - } + var globalErr error + for _, update := range updates { + if update == nil { + return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") + } + 
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate( ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings, ) if err != nil { return fmt.Errorf("failed to process update for user %s: %w", update.Id, err) } - usersToSave = append(usersToSave, updatedUser) - addUserEvents = append(addUserEvents, userEvents...) - peersToExpire = append(peersToExpire, userPeersToExpire...) if userHadPeers { updateAccountPeers = true } + + err = transaction.SaveUser(ctx, updatedUser) + if err != nil { + return fmt.Errorf("failed to save updated user %s: %w", update.Id, err) + } + + usersToSave = append(usersToSave, updatedUser) + addUserEvents = append(addUserEvents, userEvents...) + peersToExpire = append(peersToExpire, userPeersToExpire...) + + return nil + }) + if err != nil { + log.WithContext(ctx).Errorf("failed to save user %s: %s", update.Id, err) + if len(updates) == 1 { + return nil, err + } + globalErr = errors.Join(globalErr, err) + // continue when updating multiple users } - return transaction.SaveUsers(ctx, usersToSave) - }) - if err != nil { - return nil, err } - var updatedUsersInfo = make([]*types.UserInfo, 0, len(updates)) + var updatedUsersInfo = make([]*types.UserInfo, 0, len(usersToSave)) userInfos, err := am.GetUsersFromAccount(ctx, accountID, initiatorUserID) if err != nil { @@ -578,7 +591,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, am.UpdateAccountPeers(ctx, accountID) } - return updatedUsersInfo, nil + return updatedUsersInfo, globalErr } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. 
@@ -643,7 +656,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } transferredOwnerRole = result - userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthUpdate, updatedUser.AccountID, update.Id) + userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, updatedUser.AccountID, update.Id) if err != nil { return false, nil, nil, nil, err } From d4c067f0af564f14d925fb74f09e18335cb5a6f5 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 29 Aug 2025 17:40:05 +0200 Subject: [PATCH 07/41] [client] Don't deactivate upstream resolvers on failure (#4128) --- client/internal/connect.go | 3 +- client/internal/dns/config/domains.go | 201 +++++++++ client/internal/dns/config/domains_test.go | 213 +++++++++ client/internal/dns/handler_chain.go | 16 +- client/internal/dns/local/local.go | 2 +- client/internal/dns/mgmt/mgmt.go | 360 +++++++++++++++ client/internal/dns/mgmt/mgmt_test.go | 416 ++++++++++++++++++ client/internal/dns/mock_server.go | 24 +- client/internal/dns/server.go | 103 +++-- client/internal/dns/server_test.go | 24 +- client/internal/dns/upstream.go | 246 +++++++---- client/internal/dns/upstream_android.go | 9 +- client/internal/dns/upstream_general.go | 5 +- client/internal/dns/upstream_ios.go | 7 +- client/internal/dns/upstream_test.go | 32 +- client/internal/engine.go | 46 +- client/internal/engine_test.go | 6 +- client/internal/login.go | 18 +- .../routemanager/dnsinterceptor/handler.go | 34 +- 19 files changed, 1598 insertions(+), 167 deletions(-) create mode 100644 client/internal/dns/config/domains.go create mode 100644 client/internal/dns/config/domains_test.go create mode 100644 client/internal/dns/mgmt/mgmt.go create mode 100644 client/internal/dns/mgmt/mgmt_test.go diff --git a/client/internal/connect.go b/client/internal/connect.go index c2b7db0e2..f20b8d361 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -275,11 
+275,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan c.engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engineMutex.Unlock() - if err := c.engine.Start(); err != nil { + if err := c.engine.Start(loginResp.GetNetbirdConfig(), c.config.ManagementURL); err != nil { log.Errorf("error while starting Netbird Connection Engine: %s", err) return wrapErr(err) } + log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) state.Set(StatusConnected) diff --git a/client/internal/dns/config/domains.go b/client/internal/dns/config/domains.go new file mode 100644 index 000000000..cb651f1e5 --- /dev/null +++ b/client/internal/dns/config/domains.go @@ -0,0 +1,201 @@ +package config + +import ( + "errors" + "fmt" + "net" + "net/netip" + "net/url" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/shared/management/domain" + mgmProto "github.com/netbirdio/netbird/shared/management/proto" +) + +var ( + ErrEmptyURL = errors.New("empty URL") + ErrEmptyHost = errors.New("empty host") + ErrIPNotAllowed = errors.New("IP address not allowed") +) + +// ServerDomains represents the management server domains extracted from NetBird configuration +type ServerDomains struct { + Signal domain.Domain + Relay []domain.Domain + Flow domain.Domain + Stuns []domain.Domain + Turns []domain.Domain +} + +// ExtractFromNetbirdConfig extracts domain information from NetBird protobuf configuration +func ExtractFromNetbirdConfig(config *mgmProto.NetbirdConfig) ServerDomains { + if config == nil { + return ServerDomains{} + } + + domains := ServerDomains{} + + domains.Signal = extractSignalDomain(config) + domains.Relay = extractRelayDomains(config) + domains.Flow = extractFlowDomain(config) + domains.Stuns = extractStunDomains(config) + domains.Turns = extractTurnDomains(config) + + return domains +} + +// ExtractValidDomain extracts a valid domain from a URL, filtering out IP addresses +func 
ExtractValidDomain(rawURL string) (domain.Domain, error) { + if rawURL == "" { + return "", ErrEmptyURL + } + + parsedURL, err := url.Parse(rawURL) + if err == nil { + if domain, err := extractFromParsedURL(parsedURL); err != nil || domain != "" { + return domain, err + } + } + + return extractFromRawString(rawURL) +} + +// extractFromParsedURL handles domain extraction from successfully parsed URLs +func extractFromParsedURL(parsedURL *url.URL) (domain.Domain, error) { + if parsedURL.Hostname() != "" { + return extractDomainFromHost(parsedURL.Hostname()) + } + + if parsedURL.Opaque == "" || parsedURL.Scheme == "" { + return "", nil + } + + // Handle URLs with opaque content (e.g., stun:host:port) + if strings.Contains(parsedURL.Scheme, ".") { + // This is likely "domain.com:port" being parsed as scheme:opaque + reconstructed := parsedURL.Scheme + ":" + parsedURL.Opaque + if host, _, err := net.SplitHostPort(reconstructed); err == nil { + return extractDomainFromHost(host) + } + return extractDomainFromHost(parsedURL.Scheme) + } + + // Valid scheme with opaque content (e.g., stun:host:port) + host := parsedURL.Opaque + if queryIndex := strings.Index(host, "?"); queryIndex > 0 { + host = host[:queryIndex] + } + + if hostOnly, _, err := net.SplitHostPort(host); err == nil { + return extractDomainFromHost(hostOnly) + } + + return extractDomainFromHost(host) +} + +// extractFromRawString handles domain extraction when URL parsing fails or returns no results +func extractFromRawString(rawURL string) (domain.Domain, error) { + if host, _, err := net.SplitHostPort(rawURL); err == nil { + return extractDomainFromHost(host) + } + + return extractDomainFromHost(rawURL) +} + +// extractDomainFromHost extracts domain from a host string, filtering out IP addresses +func extractDomainFromHost(host string) (domain.Domain, error) { + if host == "" { + return "", ErrEmptyHost + } + + if _, err := netip.ParseAddr(host); err == nil { + return "", fmt.Errorf("%w: %s", ErrIPNotAllowed, 
host) + } + + d, err := domain.FromString(host) + if err != nil { + return "", fmt.Errorf("invalid domain: %v", err) + } + + return d, nil +} + +// extractSingleDomain extracts a single domain from a URL with error logging +func extractSingleDomain(url, serviceType string) domain.Domain { + if url == "" { + return "" + } + + d, err := ExtractValidDomain(url) + if err != nil { + log.Debugf("Skipping %s: %v", serviceType, err) + return "" + } + + return d +} + +// extractMultipleDomains extracts multiple domains from URLs with error logging +func extractMultipleDomains(urls []string, serviceType string) []domain.Domain { + var domains []domain.Domain + for _, url := range urls { + if url == "" { + continue + } + d, err := ExtractValidDomain(url) + if err != nil { + log.Debugf("Skipping %s: %v", serviceType, err) + continue + } + domains = append(domains, d) + } + return domains +} + +// extractSignalDomain extracts the signal domain from NetBird configuration. +func extractSignalDomain(config *mgmProto.NetbirdConfig) domain.Domain { + if config.Signal != nil { + return extractSingleDomain(config.Signal.Uri, "signal") + } + return "" +} + +// extractRelayDomains extracts relay server domains from NetBird configuration. +func extractRelayDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + if config.Relay != nil { + return extractMultipleDomains(config.Relay.Urls, "relay") + } + return nil +} + +// extractFlowDomain extracts the traffic flow domain from NetBird configuration. +func extractFlowDomain(config *mgmProto.NetbirdConfig) domain.Domain { + if config.Flow != nil { + return extractSingleDomain(config.Flow.Url, "flow") + } + return "" +} + +// extractStunDomains extracts STUN server domains from NetBird configuration. 
+func extractStunDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + var urls []string + for _, stun := range config.Stuns { + if stun != nil && stun.Uri != "" { + urls = append(urls, stun.Uri) + } + } + return extractMultipleDomains(urls, "STUN") +} + +// extractTurnDomains extracts TURN server domains from NetBird configuration. +func extractTurnDomains(config *mgmProto.NetbirdConfig) []domain.Domain { + var urls []string + for _, turn := range config.Turns { + if turn != nil && turn.HostConfig != nil && turn.HostConfig.Uri != "" { + urls = append(urls, turn.HostConfig.Uri) + } + } + return extractMultipleDomains(urls, "TURN") +} diff --git a/client/internal/dns/config/domains_test.go b/client/internal/dns/config/domains_test.go new file mode 100644 index 000000000..5eae3a541 --- /dev/null +++ b/client/internal/dns/config/domains_test.go @@ -0,0 +1,213 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExtractValidDomain(t *testing.T) { + tests := []struct { + name string + url string + expected string + expectError bool + }{ + { + name: "HTTPS URL with port", + url: "https://api.netbird.io:443", + expected: "api.netbird.io", + }, + { + name: "HTTP URL without port", + url: "http://signal.example.com", + expected: "signal.example.com", + }, + { + name: "Host with port (no scheme)", + url: "signal.netbird.io:443", + expected: "signal.netbird.io", + }, + { + name: "STUN URL", + url: "stun:stun.netbird.io:443", + expected: "stun.netbird.io", + }, + { + name: "STUN URL with different port", + url: "stun:stun.netbird.io:5555", + expected: "stun.netbird.io", + }, + { + name: "TURNS URL with query params", + url: "turns:turn.netbird.io:443?transport=tcp", + expected: "turn.netbird.io", + }, + { + name: "TURN URL", + url: "turn:turn.example.com:3478", + expected: "turn.example.com", + }, + { + name: "REL URL", + url: "rel://relay.example.com:443", + expected: "relay.example.com", + }, + { + name: "RELS URL", + url: 
"rels://relay.netbird.io:443", + expected: "relay.netbird.io", + }, + { + name: "Raw hostname", + url: "example.org", + expected: "example.org", + }, + { + name: "IP address should be rejected", + url: "192.168.1.1", + expectError: true, + }, + { + name: "IP address with port should be rejected", + url: "192.168.1.1:443", + expectError: true, + }, + { + name: "IPv6 address should be rejected", + url: "2001:db8::1", + expectError: true, + }, + { + name: "HTTP URL with IPv4 should be rejected", + url: "http://192.168.1.1:8080", + expectError: true, + }, + { + name: "HTTPS URL with IPv4 should be rejected", + url: "https://10.0.0.1:443", + expectError: true, + }, + { + name: "STUN URL with IPv4 should be rejected", + url: "stun:192.168.1.1:3478", + expectError: true, + }, + { + name: "TURN URL with IPv4 should be rejected", + url: "turn:10.0.0.1:3478", + expectError: true, + }, + { + name: "TURNS URL with IPv4 should be rejected", + url: "turns:172.16.0.1:5349", + expectError: true, + }, + { + name: "HTTP URL with IPv6 should be rejected", + url: "http://[2001:db8::1]:8080", + expectError: true, + }, + { + name: "HTTPS URL with IPv6 should be rejected", + url: "https://[::1]:443", + expectError: true, + }, + { + name: "STUN URL with IPv6 should be rejected", + url: "stun:[2001:db8::1]:3478", + expectError: true, + }, + { + name: "IPv6 with port should be rejected", + url: "[2001:db8::1]:443", + expectError: true, + }, + { + name: "Localhost IPv4 should be rejected", + url: "127.0.0.1:8080", + expectError: true, + }, + { + name: "Localhost IPv6 should be rejected", + url: "[::1]:443", + expectError: true, + }, + { + name: "REL URL with IPv4 should be rejected", + url: "rel://192.168.1.1:443", + expectError: true, + }, + { + name: "RELS URL with IPv4 should be rejected", + url: "rels://10.0.0.1:443", + expectError: true, + }, + { + name: "Empty URL", + url: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, 
err := ExtractValidDomain(tt.url) + + if tt.expectError { + assert.Error(t, err, "Expected error for URL: %s", tt.url) + } else { + assert.NoError(t, err, "Unexpected error for URL: %s", tt.url) + assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for URL: %s", tt.url) + } + }) + } +} + +func TestExtractDomainFromHost(t *testing.T) { + tests := []struct { + name string + host string + expected string + expectError bool + }{ + { + name: "Valid domain", + host: "example.com", + expected: "example.com", + }, + { + name: "Subdomain", + host: "api.example.com", + expected: "api.example.com", + }, + { + name: "IPv4 address", + host: "192.168.1.1", + expectError: true, + }, + { + name: "IPv6 address", + host: "2001:db8::1", + expectError: true, + }, + { + name: "Empty host", + host: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := extractDomainFromHost(tt.host) + + if tt.expectError { + assert.Error(t, err, "Expected error for host: %s", tt.host) + } else { + assert.NoError(t, err, "Unexpected error for host: %s", tt.host) + assert.Equal(t, tt.expected, result.SafeString(), "Domain mismatch for host: %s", tt.host) + } + }) + } +} diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 439bcbb3c..2e54bffd9 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -11,11 +11,12 @@ import ( ) const ( - PriorityLocal = 100 - PriorityDNSRoute = 75 - PriorityUpstream = 50 - PriorityDefault = 1 - PriorityFallback = -100 + PriorityMgmtCache = 150 + PriorityLocal = 100 + PriorityDNSRoute = 75 + PriorityUpstream = 50 + PriorityDefault = 1 + PriorityFallback = -100 ) type SubdomainMatcher interface { @@ -182,7 +183,10 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // If handler wants to continue, try next handler if chainWriter.shouldContinue { - log.Tracef("handler requested continue to next 
handler for domain=%s", qname) + // Only log continue for non-management cache handlers to reduce noise + if entry.Priority != PriorityMgmtCache { + log.Tracef("handler requested continue to next handler for domain=%s", qname) + } continue } return diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index b776fbbe3..bac7875ec 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -34,7 +34,7 @@ func (d *Resolver) MatchSubdomains() bool { // String returns a string representation of the local resolver func (d *Resolver) String() string { - return fmt.Sprintf("local resolver [%d records]", len(d.records)) + return fmt.Sprintf("LocalResolver [%d records]", len(d.records)) } func (d *Resolver) Stop() {} diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go new file mode 100644 index 000000000..290395473 --- /dev/null +++ b/client/internal/dns/mgmt/mgmt.go @@ -0,0 +1,360 @@ +package mgmt + +import ( + "context" + "fmt" + "net" + "net/url" + "strings" + "sync" + "time" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" + "github.com/netbirdio/netbird/shared/management/domain" +) + +const dnsTimeout = 5 * time.Second + +// Resolver caches critical NetBird infrastructure domains +type Resolver struct { + records map[dns.Question][]dns.RR + mgmtDomain *domain.Domain + serverDomains *dnsconfig.ServerDomains + mutex sync.RWMutex +} + +// NewResolver creates a new management domains cache resolver. +func NewResolver() *Resolver { + return &Resolver{ + records: make(map[dns.Question][]dns.RR), + } +} + +// String returns a string representation of the resolver. +func (m *Resolver) String() string { + return "MgmtCacheResolver" +} + +// ServeDNS implements dns.Handler interface. 
+func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + m.continueToNext(w, r) + return + } + + question := r.Question[0] + question.Name = strings.ToLower(dns.Fqdn(question.Name)) + + if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA { + m.continueToNext(w, r) + return + } + + m.mutex.RLock() + records, found := m.records[question] + m.mutex.RUnlock() + + if !found { + m.continueToNext(w, r) + return + } + + resp := &dns.Msg{} + resp.SetReply(r) + resp.Authoritative = false + resp.RecursionAvailable = true + + resp.Answer = append(resp.Answer, records...) + + log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write response: %v", err) + } +} + +// MatchSubdomains returns false since this resolver only handles exact domain matches +// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains. +func (m *Resolver) MatchSubdomains() bool { + return false +} + +// continueToNext signals the handler chain to continue to the next handler. +func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) { + resp := &dns.Msg{} + resp.SetRcode(r, dns.RcodeNameError) + resp.MsgHdr.Zero = true + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write continue signal: %v", err) + } +} + +// AddDomain manually adds a domain to cache by resolving it. 
+func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { + dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) + + ctx, cancel := context.WithTimeout(ctx, dnsTimeout) + defer cancel() + + ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) + if err != nil { + return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err) + } + + var aRecords, aaaaRecords []dns.RR + for _, ip := range ips { + if ip.Is4() { + rr := &dns.A{ + Hdr: dns.RR_Header{ + Name: dnsName, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: ip.AsSlice(), + } + aRecords = append(aRecords, rr) + } else if ip.Is6() { + rr := &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: dnsName, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 300, + }, + AAAA: ip.AsSlice(), + } + aaaaRecords = append(aaaaRecords, rr) + } + } + + m.mutex.Lock() + + if len(aRecords) > 0 { + aQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + } + m.records[aQuestion] = aRecords + } + + if len(aaaaRecords) > 0 { + aaaaQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + } + m.records[aaaaQuestion] = aaaaRecords + } + + m.mutex.Unlock() + + log.Debugf("added domain=%s with %d A records and %d AAAA records", + d.SafeString(), len(aRecords), len(aaaaRecords)) + + return nil +} + +// PopulateFromConfig extracts and caches domains from the client configuration. +func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error { + if mgmtURL == nil { + return nil + } + + d, err := dnsconfig.ExtractValidDomain(mgmtURL.String()) + if err != nil { + return fmt.Errorf("extract domain from URL: %w", err) + } + + m.mutex.Lock() + m.mgmtDomain = &d + m.mutex.Unlock() + + if err := m.AddDomain(ctx, d); err != nil { + return fmt.Errorf("add domain: %w", err) + } + + return nil +} + +// RemoveDomain removes a domain from the cache. 
+func (m *Resolver) RemoveDomain(d domain.Domain) error { + dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) + + m.mutex.Lock() + defer m.mutex.Unlock() + + aQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + } + delete(m.records, aQuestion) + + aaaaQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + } + delete(m.records, aaaaQuestion) + + log.Debugf("removed domain=%s from cache", d.SafeString()) + return nil +} + +// GetCachedDomains returns a list of all cached domains. +func (m *Resolver) GetCachedDomains() domain.List { + m.mutex.RLock() + defer m.mutex.RUnlock() + + domainSet := make(map[domain.Domain]struct{}) + for question := range m.records { + domainName := strings.TrimSuffix(question.Name, ".") + domainSet[domain.Domain(domainName)] = struct{}{} + } + + domains := make(domain.List, 0, len(domainSet)) + for d := range domainSet { + domains = append(domains, d) + } + + return domains +} + +// UpdateFromServerDomains updates the cache with server domains from network configuration. +// It merges new domains with existing ones, replacing entire domain types when updated. +// Empty updates are ignored to prevent clearing infrastructure domains during partial updates. 
+func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) { + newDomains := m.extractDomainsFromServerDomains(serverDomains) + var removedDomains domain.List + + if len(newDomains) > 0 { + m.mutex.Lock() + if m.serverDomains == nil { + m.serverDomains = &dnsconfig.ServerDomains{} + } + updatedServerDomains := m.mergeServerDomains(*m.serverDomains, serverDomains) + m.serverDomains = &updatedServerDomains + m.mutex.Unlock() + + allDomains := m.extractDomainsFromServerDomains(updatedServerDomains) + currentDomains := m.GetCachedDomains() + removedDomains = m.removeStaleDomains(currentDomains, allDomains) + } + + m.addNewDomains(ctx, newDomains) + + return removedDomains, nil +} + +// removeStaleDomains removes cached domains not present in the target domain list. +// Management domains are preserved and never removed during server domain updates. +func (m *Resolver) removeStaleDomains(currentDomains, newDomains domain.List) domain.List { + var removedDomains domain.List + + for _, currentDomain := range currentDomains { + if m.isDomainInList(currentDomain, newDomains) { + continue + } + + if m.isManagementDomain(currentDomain) { + continue + } + + removedDomains = append(removedDomains, currentDomain) + if err := m.RemoveDomain(currentDomain); err != nil { + log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err) + } + } + + return removedDomains +} + +// mergeServerDomains merges new server domains with existing ones. +// When a domain type is provided in the new domains, it completely replaces that type. 
+func (m *Resolver) mergeServerDomains(existing, incoming dnsconfig.ServerDomains) dnsconfig.ServerDomains { + merged := existing + + if incoming.Signal != "" { + merged.Signal = incoming.Signal + } + if len(incoming.Relay) > 0 { + merged.Relay = incoming.Relay + } + if incoming.Flow != "" { + merged.Flow = incoming.Flow + } + if len(incoming.Stuns) > 0 { + merged.Stuns = incoming.Stuns + } + if len(incoming.Turns) > 0 { + merged.Turns = incoming.Turns + } + + return merged +} + +// isDomainInList checks if domain exists in the list +func (m *Resolver) isDomainInList(domain domain.Domain, list domain.List) bool { + for _, d := range list { + if domain.SafeString() == d.SafeString() { + return true + } + } + return false +} + +// isManagementDomain checks if domain is the protected management domain +func (m *Resolver) isManagementDomain(domain domain.Domain) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + + return m.mgmtDomain != nil && domain == *m.mgmtDomain +} + +// addNewDomains resolves and caches all domains from the update +func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) { + for _, newDomain := range newDomains { + if err := m.AddDomain(ctx, newDomain); err != nil { + log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err) + } else { + log.Debugf("added/updated management cache domain=%s", newDomain.SafeString()) + } + } +} + +func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List { + var domains domain.List + + if serverDomains.Signal != "" { + domains = append(domains, serverDomains.Signal) + } + + for _, relay := range serverDomains.Relay { + if relay != "" { + domains = append(domains, relay) + } + } + + if serverDomains.Flow != "" { + domains = append(domains, serverDomains.Flow) + } + + for _, stun := range serverDomains.Stuns { + if stun != "" { + domains = append(domains, stun) + } + } + + for _, turn := range serverDomains.Turns { + if turn != "" { 
+ domains = append(domains, turn) + } + } + + return domains +} diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go new file mode 100644 index 000000000..99d289871 --- /dev/null +++ b/client/internal/dns/mgmt/mgmt_test.go @@ -0,0 +1,416 @@ +package mgmt + +import ( + "context" + "fmt" + "net/url" + "strings" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" + "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/shared/management/domain" +) + +func TestResolver_NewResolver(t *testing.T) { + resolver := NewResolver() + + assert.NotNil(t, resolver) + assert.NotNil(t, resolver.records) + assert.False(t, resolver.MatchSubdomains()) +} + +func TestResolver_ExtractDomainFromURL(t *testing.T) { + tests := []struct { + name string + urlStr string + expectedDom string + expectError bool + }{ + { + name: "HTTPS URL with port", + urlStr: "https://api.netbird.io:443", + expectedDom: "api.netbird.io", + expectError: false, + }, + { + name: "HTTP URL without port", + urlStr: "http://signal.example.com", + expectedDom: "signal.example.com", + expectError: false, + }, + { + name: "URL with path", + urlStr: "https://relay.netbird.io/status", + expectedDom: "relay.netbird.io", + expectError: false, + }, + { + name: "Invalid URL", + urlStr: "not-a-valid-url", + expectedDom: "not-a-valid-url", + expectError: false, + }, + { + name: "Empty URL", + urlStr: "", + expectedDom: "", + expectError: true, + }, + { + name: "STUN URL", + urlStr: "stun:stun.example.com:3478", + expectedDom: "stun.example.com", + expectError: false, + }, + { + name: "TURN URL", + urlStr: "turn:turn.example.com:3478", + expectedDom: "turn.example.com", + expectError: false, + }, + { + name: "REL URL", + urlStr: "rel://relay.example.com:443", + expectedDom: "relay.example.com", + expectError: false, + }, + { + name: "RELS URL", + urlStr: 
"rels://relay.example.com:443", + expectedDom: "relay.example.com", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var parsedURL *url.URL + var err error + + if tt.urlStr != "" { + parsedURL, err = url.Parse(tt.urlStr) + if err != nil && !tt.expectError { + t.Fatalf("Failed to parse URL: %v", err) + } + } + + domain, err := extractDomainFromURL(parsedURL) + + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedDom, domain.SafeString()) + } + }) + } +} + +func TestResolver_PopulateFromConfig(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := NewResolver() + + // Test with IP address - should return error since IP addresses are rejected + mgmtURL, _ := url.Parse("https://127.0.0.1") + + err := resolver.PopulateFromConfig(ctx, mgmtURL) + assert.Error(t, err) + assert.ErrorIs(t, err, dnsconfig.ErrIPNotAllowed) + + // No domains should be cached when using IP addresses + domains := resolver.GetCachedDomains() + assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses") +} + +func TestResolver_ServeDNS(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + // Add a test domain to the cache - use example.org which is reserved for testing + testDomain, err := domain.FromString("example.org") + if err != nil { + t.Fatalf("Failed to create domain: %v", err) + } + err = resolver.AddDomain(ctx, testDomain) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + // Test A record query for cached domain + t.Run("Cached domain A record", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + 
assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode) + assert.True(t, len(capturedMsg.Answer) > 0, "Should have at least one answer") + }) + + // Test uncached domain signals to continue to next handler + t.Run("Uncached domain signals continue to next handler", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + req := new(dns.Msg) + req.SetQuestion("unknown.example.com.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode) + // Zero flag set to true signals the handler chain to continue to next handler + assert.True(t, capturedMsg.MsgHdr.Zero, "Zero flag should be set to signal continuation to next handler") + assert.Empty(t, capturedMsg.Answer, "Should have no answers for uncached domain") + }) + + // Test that subdomains of cached domains are NOT resolved + t.Run("Subdomains of cached domains are not resolved", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + // Query for a subdomain of our cached domain + req := new(dns.Msg) + req.SetQuestion("sub.example.org.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode) + assert.True(t, capturedMsg.MsgHdr.Zero, "Should signal continuation to next handler for subdomains") + assert.Empty(t, capturedMsg.Answer, "Should have no answers for subdomains") + }) + + // Test case-insensitive matching + t.Run("Case-insensitive domain matching", func(t *testing.T) { + var capturedMsg *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + capturedMsg = m + return nil + }, + } + + // Query with different casing + req := new(dns.Msg) + 
req.SetQuestion("EXAMPLE.ORG.", dns.TypeA) + + resolver.ServeDNS(mockWriter, req) + + assert.NotNil(t, capturedMsg) + assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode) + assert.True(t, len(capturedMsg.Answer) > 0, "Should resolve regardless of case") + }) +} + +func TestResolver_GetCachedDomains(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + testDomain, err := domain.FromString("example.org") + if err != nil { + t.Fatalf("Failed to create domain: %v", err) + } + err = resolver.AddDomain(ctx, testDomain) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + cachedDomains := resolver.GetCachedDomains() + + assert.Equal(t, 1, len(cachedDomains), "Should return exactly one domain for single added domain") + assert.Equal(t, testDomain.SafeString(), cachedDomains[0].SafeString(), "Cached domain should match original") + assert.False(t, strings.HasSuffix(cachedDomains[0].PunycodeString(), "."), "Domain should not have trailing dot") +} + +func TestResolver_ManagementDomainProtection(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + mgmtURL, _ := url.Parse("https://example.org") + err := resolver.PopulateFromConfig(ctx, mgmtURL) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + initialDomains := resolver.GetCachedDomains() + if len(initialDomains) == 0 { + t.Skip("Management domain failed to resolve, skipping test") + } + assert.Equal(t, 1, len(initialDomains), "Should have management domain cached") + assert.Equal(t, "example.org", initialDomains[0].SafeString()) + + serverDomains := dnsconfig.ServerDomains{ + Signal: "google.com", + Relay: []domain.Domain{"cloudflare.com"}, + } + + _, err = resolver.UpdateFromServerDomains(ctx, serverDomains) + if err != nil { + t.Logf("Server domains update failed: %v", err) + } + + finalDomains := resolver.GetCachedDomains() + + managementStillCached := false + for _, d := range finalDomains { + 
if d.SafeString() == "example.org" { + managementStillCached = true + break + } + } + assert.True(t, managementStillCached, "Management domain should never be removed") +} + +// extractDomainFromURL extracts a domain from a URL - test helper function +func extractDomainFromURL(u *url.URL) (domain.Domain, error) { + if u == nil { + return "", fmt.Errorf("URL is nil") + } + return dnsconfig.ExtractValidDomain(u.String()) +} + +func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + // Set up initial domains using resolvable domains + initialDomains := dnsconfig.ServerDomains{ + Signal: "example.org", + Stuns: []domain.Domain{"google.com"}, + Turns: []domain.Domain{"cloudflare.com"}, + } + + // Add initial domains + _, err := resolver.UpdateFromServerDomains(ctx, initialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + // Verify domains were added + cachedDomains := resolver.GetCachedDomains() + assert.Len(t, cachedDomains, 3) + + // Update with empty ServerDomains (simulating partial network map update) + emptyDomains := dnsconfig.ServerDomains{} + removedDomains, err := resolver.UpdateFromServerDomains(ctx, emptyDomains) + assert.NoError(t, err) + + // Verify no domains were removed + assert.Len(t, removedDomains, 0, "No domains should be removed when update is empty") + + // Verify all original domains are still cached + finalDomains := resolver.GetCachedDomains() + assert.Len(t, finalDomains, 3, "All original domains should still be cached") +} + +func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + // Set up initial complete domains using resolvable domains + initialDomains := dnsconfig.ServerDomains{ + Signal: "example.org", + Stuns: []domain.Domain{"google.com"}, + Turns: []domain.Domain{"cloudflare.com"}, + } + + // Add initial domains + _, err := 
resolver.UpdateFromServerDomains(ctx, initialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + assert.Len(t, resolver.GetCachedDomains(), 3) + + // Update with partial ServerDomains (only signal domain - this should replace signal but preserve stun/turn) + partialDomains := dnsconfig.ServerDomains{ + Signal: "github.com", + } + removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + // Should remove only the old signal domain + assert.Len(t, removedDomains, 1, "Should remove only the old signal domain") + assert.Equal(t, "example.org", removedDomains[0].SafeString()) + + finalDomains := resolver.GetCachedDomains() + assert.Len(t, finalDomains, 3, "Should have new signal plus preserved stun/turn domains") + + domainStrings := make([]string, len(finalDomains)) + for i, d := range finalDomains { + domainStrings[i] = d.SafeString() + } + assert.Contains(t, domainStrings, "github.com") + assert.Contains(t, domainStrings, "google.com") + assert.Contains(t, domainStrings, "cloudflare.com") + assert.NotContains(t, domainStrings, "example.org") +} + +func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) { + resolver := NewResolver() + ctx := context.Background() + + // Set up initial complete domains using resolvable domains + initialDomains := dnsconfig.ServerDomains{ + Signal: "example.org", + Stuns: []domain.Domain{"google.com"}, + Turns: []domain.Domain{"cloudflare.com"}, + } + + // Add initial domains + _, err := resolver.UpdateFromServerDomains(ctx, initialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + assert.Len(t, resolver.GetCachedDomains(), 3) + + // Update with partial ServerDomains (only flow domain - new type, should preserve all existing) + partialDomains := dnsconfig.ServerDomains{ + Flow: "github.com", + } + removedDomains, err 
:= resolver.UpdateFromServerDomains(ctx, partialDomains) + if err != nil { + t.Skipf("Skipping test due to DNS resolution failure: %v", err) + } + + assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type") + + finalDomains := resolver.GetCachedDomains() + assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain") + + domainStrings := make([]string, len(finalDomains)) + for i, d := range finalDomains { + domainStrings[i] = d.SafeString() + } + assert.Contains(t, domainStrings, "example.org") + assert.Contains(t, domainStrings, "google.com") + assert.Contains(t, domainStrings, "cloudflare.com") + assert.Contains(t, domainStrings, "github.com") +} diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index d160fa99a..0f89b9016 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -3,20 +3,23 @@ package dns import ( "fmt" "net/netip" + "net/url" "github.com/miekg/dns" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/shared/management/domain" ) // MockServer is the mock instance of a dns server type MockServer struct { - InitializeFunc func() error - StopFunc func() - UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error - RegisterHandlerFunc func(domain.List, dns.Handler, int) - DeregisterHandlerFunc func(domain.List, int) + InitializeFunc func() error + StopFunc func() + UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error + RegisterHandlerFunc func(domain.List, dns.Handler, int) + DeregisterHandlerFunc func(domain.List, int) + UpdateServerConfigFunc func(domains dnsconfig.ServerDomains) error } func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) { @@ -70,3 +73,14 @@ func (m *MockServer) SearchDomains() []string { // ProbeAvailability mocks implementation of ProbeAvailability from the 
Server interface func (m *MockServer) ProbeAvailability() { } + +func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { + if m.UpdateServerConfigFunc != nil { + return m.UpdateServerConfigFunc(domains) + } + return nil +} + +func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { + return nil +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index cbcf6a256..8cb886203 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/netip" + "net/url" "runtime" "strings" "sync" @@ -15,7 +16,9 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/iface/netstack" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/client/internal/dns/local" + "github.com/netbirdio/netbird/client/internal/dns/mgmt" "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" @@ -45,6 +48,8 @@ type Server interface { OnUpdatedHostDNSServer(addrs []netip.AddrPort) SearchDomains() []string ProbeAvailability() + UpdateServerConfig(domains dnsconfig.ServerDomains) error + PopulateManagementDomain(mgmtURL *url.URL) error } type nsGroupsByDomain struct { @@ -77,6 +82,8 @@ type DefaultServer struct { handlerChain *HandlerChain extraDomains map[domain.Domain]int + mgmtCacheResolver *mgmt.Resolver + // permanent related properties permanent bool hostsDNSHolder *hostsDNSHolder @@ -104,18 +111,20 @@ type handlerWrapper struct { type registeredHandlerMap map[types.HandlerID]handlerWrapper +// DefaultServerConfig holds configuration parameters for NewDefaultServer +type DefaultServerConfig struct { + WgInterface WGIface + CustomAddress string + StatusRecorder *peer.Status + StateManager *statemanager.Manager + DisableSys bool +} + // NewDefaultServer returns a new dns server -func NewDefaultServer( - ctx 
context.Context, - wgInterface WGIface, - customAddress string, - statusRecorder *peer.Status, - stateManager *statemanager.Manager, - disableSys bool, -) (*DefaultServer, error) { +func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*DefaultServer, error) { var addrPort *netip.AddrPort - if customAddress != "" { - parsedAddrPort, err := netip.ParseAddrPort(customAddress) + if config.CustomAddress != "" { + parsedAddrPort, err := netip.ParseAddrPort(config.CustomAddress) if err != nil { return nil, fmt.Errorf("unable to parse the custom dns address, got error: %s", err) } @@ -123,13 +132,14 @@ func NewDefaultServer( } var dnsService service - if wgInterface.IsUserspaceBind() { - dnsService = NewServiceViaMemory(wgInterface) + if config.WgInterface.IsUserspaceBind() { + dnsService = NewServiceViaMemory(config.WgInterface) } else { - dnsService = newServiceViaListener(wgInterface, addrPort) + dnsService = newServiceViaListener(config.WgInterface, addrPort) } - return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil + server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys) + return server, nil } // NewDefaultServerPermanentUpstream returns a new dns server. 
It optimized for mobile systems @@ -178,20 +188,24 @@ func newDefaultServer( ) *DefaultServer { handlerChain := NewHandlerChain() ctx, stop := context.WithCancel(ctx) + + mgmtCacheResolver := mgmt.NewResolver() + defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - disableSys: disableSys, - service: dnsService, - handlerChain: handlerChain, - extraDomains: make(map[domain.Domain]int), - dnsMuxMap: make(registeredHandlerMap), - localResolver: local.NewResolver(), - wgInterface: wgInterface, - statusRecorder: statusRecorder, - stateManager: stateManager, - hostsDNSHolder: newHostsDNSHolder(), - hostManager: &noopHostConfigurator{}, + ctx: ctx, + ctxCancel: stop, + disableSys: disableSys, + service: dnsService, + handlerChain: handlerChain, + extraDomains: make(map[domain.Domain]int), + dnsMuxMap: make(registeredHandlerMap), + localResolver: local.NewResolver(), + wgInterface: wgInterface, + statusRecorder: statusRecorder, + stateManager: stateManager, + hostsDNSHolder: newHostsDNSHolder(), + hostManager: &noopHostConfigurator{}, + mgmtCacheResolver: mgmtCacheResolver, } // register with root zone, handler chain takes care of the routing @@ -217,7 +231,7 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler } func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { - log.Debugf("registering handler %s with priority %d", handler, priority) + log.Debugf("registering handler %s with priority %d for %v", handler, priority, domains) for _, domain := range domains { if domain == "" { @@ -246,7 +260,7 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) { } func (s *DefaultServer) deregisterHandler(domains []string, priority int) { - log.Debugf("deregistering handler %v with priority %d", domains, priority) + log.Debugf("deregistering handler with priority %d for %v", priority, domains) for _, domain := range domains { if domain == "" { @@ -432,6 +446,29 @@ func (s 
*DefaultServer) ProbeAvailability() { wg.Wait() } +func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { + s.mux.Lock() + defer s.mux.Unlock() + + if s.mgmtCacheResolver != nil { + removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains) + if err != nil { + return fmt.Errorf("update management cache resolver: %w", err) + } + + if len(removedDomains) > 0 { + s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache) + } + + newDomains := s.mgmtCacheResolver.GetCachedDomains() + if len(newDomains) > 0 { + s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache) + } + } + + return nil +} + func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { // is the service should be Disabled, we stop the listener or fake resolver if update.ServiceEnable { @@ -961,3 +998,11 @@ func toZone(d domain.Domain) domain.Domain { ), ) } + +// PopulateManagementDomain populates the DNS cache with management domain +func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error { + if s.mgmtCacheResolver != nil { + return s.mgmtCacheResolver.PopulateFromConfig(s.ctx, mgmtURL) + } + return nil +} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 068f001d8..11575d500 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -363,7 +363,13 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", peer.NewRecorder("mgm"), nil, false) + dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ + WgInterface: wgIface, + CustomAddress: "", + StatusRecorder: peer.NewRecorder("mgm"), + StateManager: nil, + DisableSys: false, + }) if err != nil { t.Fatal(err) } @@ -473,7 +479,13 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - dnsServer, err := NewDefaultServer(context.Background(), wgIface, 
"", peer.NewRecorder("mgm"), nil, false) + dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ + WgInterface: wgIface, + CustomAddress: "", + StatusRecorder: peer.NewRecorder("mgm"), + StateManager: nil, + DisableSys: false, + }) if err != nil { t.Errorf("create DNS server: %v", err) return @@ -575,7 +587,13 @@ func TestDNSServerStartStop(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, peer.NewRecorder("mgm"), nil, false) + dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{ + WgInterface: &mocWGIface{}, + CustomAddress: testCase.addrPort, + StatusRecorder: peer.NewRecorder("mgm"), + StateManager: nil, + DisableSys: false, + }) if err != nil { t.Fatalf("%v", err) } diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 071e3617a..c19e0acb5 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -33,9 +33,11 @@ func SetCurrentMTU(mtu uint16) { } const ( - UpstreamTimeout = 15 * time.Second + UpstreamTimeout = 4 * time.Second + // ClientTimeout is the timeout for the dns.Client. 
+ // Set longer than UpstreamTimeout to ensure context timeout takes precedence + ClientTimeout = 5 * time.Second - failsTillDeact = int32(5) reactivatePeriod = 30 * time.Second probeTimeout = 2 * time.Second ) @@ -58,9 +60,7 @@ type upstreamResolverBase struct { upstreamServers []netip.AddrPort domain string disabled bool - failsCount atomic.Int32 successCount atomic.Int32 - failsTillDeact int32 mutex sync.Mutex reactivatePeriod time.Duration upstreamTimeout time.Duration @@ -79,14 +79,13 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain: domain, upstreamTimeout: UpstreamTimeout, reactivatePeriod: reactivatePeriod, - failsTillDeact: failsTillDeact, statusRecorder: statusRecorder, } } // String returns a string representation of the upstream resolver func (u *upstreamResolverBase) String() string { - return fmt.Sprintf("upstream %s", u.upstreamServers) + return fmt.Sprintf("Upstream %s", u.upstreamServers) } // ID returns the unique handler ID @@ -116,58 +115,102 @@ func (u *upstreamResolverBase) Stop() { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { requestID := GenerateRequestID() logger := log.WithField("request_id", requestID) - var err error - defer func() { - u.checkUpstreamFails(err) - }() logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + + u.prepareRequest(r) + + if u.ctx.Err() != nil { + logger.Tracef("%s has been stopped", u) + return + } + + if u.tryUpstreamServers(w, r, logger) { + return + } + + u.writeErrorResponse(w, r, logger) +} + +func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { if r.Extra == nil { r.MsgHdr.AuthenticatedData = true } +} - select { - case <-u.ctx.Done(): - logger.Tracef("%s has been stopped", u) - return - default: +func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool { + timeout := u.upstreamTimeout + if 
len(u.upstreamServers) > 1 { + maxTotal := 5 * time.Second + minPerUpstream := 2 * time.Second + scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers)) + if scaledTimeout > minPerUpstream { + timeout = scaledTimeout + } else { + timeout = minPerUpstream + } } for _, upstream := range u.upstreamServers { - var rm *dns.Msg - var t time.Duration - - func() { - ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) - defer cancel() - rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) - }() - - if err != nil { - if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { - logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) - continue - } - logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) - continue + if u.queryUpstream(w, r, upstream, timeout, logger) { + return true } + } + return false +} - if rm == nil || !rm.Response { - logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) - continue - } +func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool { + var rm *dns.Msg + var t time.Duration + var err error - u.successCount.Add(1) - logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) + var startTime time.Time + func() { + ctx, cancel := context.WithTimeout(u.ctx, timeout) + defer cancel() + startTime = time.Now() + rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) + }() - if err = w.WriteMsg(rm); err != nil { - logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) - } - // count the fails only if they happen sequentially - u.failsCount.Store(0) + if err != nil { + u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger) + return false + } + + if rm == nil || 
!rm.Response { + logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) + return false + } + + return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger) +} + +func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) { + if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) { + logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err) return } - u.failsCount.Add(1) + + elapsed := time.Since(startTime) + timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout) + if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" { + timeoutMsg += " " + peerInfo + } + timeoutMsg += fmt.Sprintf(" - error: %v", err) + logger.Warnf(timeoutMsg) +} + +func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool { + u.successCount.Add(1) + logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain) + + if err := w.WriteMsg(rm); err != nil { + logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err) + } + return true +} + +func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) { logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) m := new(dns.Msg) @@ -177,41 +220,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } } -// checkUpstreamFails counts fails and disables or enables upstream resolving -// -// If fails count is greater that failsTillDeact, upstream resolving -// will be disabled for reactivatePeriod, after that time period fails counter -// will be reset and upstream will be 
reactivated. -func (u *upstreamResolverBase) checkUpstreamFails(err error) { - u.mutex.Lock() - defer u.mutex.Unlock() - - if u.failsCount.Load() < u.failsTillDeact || u.disabled { - return - } - - select { - case <-u.ctx.Done(): - return - default: - } - - u.disable(err) - - if u.statusRecorder == nil { - return - } - - u.statusRecorder.PublishEvent( - proto.SystemEvent_WARNING, - proto.SystemEvent_DNS, - "All upstream servers failed (fail count exceeded)", - "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", - map[string]string{"upstreams": u.upstreamServersString()}, - // TODO add domain meta - ) -} - // ProbeAvailability tests all upstream servers simultaneously and // disables the resolver if none work func (u *upstreamResolverBase) ProbeAvailability() { @@ -224,8 +232,8 @@ func (u *upstreamResolverBase) ProbeAvailability() { default: } - // avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact - if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact { + // avoid probe if upstreams could resolve at least one query + if u.successCount.Load() > 0 { return } @@ -312,7 +320,6 @@ func (u *upstreamResolverBase) waitUntilResponse() { } log.Infof("upstreams %s are responsive again. 
Adding them back to system", u.upstreamServersString()) - u.failsCount.Store(0) u.successCount.Add(1) u.reactivate() u.disabled = false @@ -416,3 +423,80 @@ func GenerateRequestID() string { } return hex.EncodeToString(bytes) } + +// FormatPeerStatus formats peer connection status information for debugging DNS timeouts +func FormatPeerStatus(peerState *peer.State) string { + isConnected := peerState.ConnStatus == peer.StatusConnected + hasRecentHandshake := !peerState.LastWireguardHandshake.IsZero() && + time.Since(peerState.LastWireguardHandshake) < 3*time.Minute + + statusInfo := fmt.Sprintf("%s:%s", peerState.FQDN, peerState.IP) + + switch { + case !isConnected: + statusInfo += " DISCONNECTED" + case !hasRecentHandshake: + statusInfo += " NO_RECENT_HANDSHAKE" + default: + statusInfo += " connected" + } + + if !peerState.LastWireguardHandshake.IsZero() { + timeSinceHandshake := time.Since(peerState.LastWireguardHandshake) + statusInfo += fmt.Sprintf(" last_handshake=%v_ago", timeSinceHandshake.Truncate(time.Second)) + } else { + statusInfo += " no_handshake" + } + + if peerState.Relayed { + statusInfo += " via_relay" + } + + if peerState.Latency > 0 { + statusInfo += fmt.Sprintf(" latency=%v", peerState.Latency) + } + + return statusInfo +} + +// findPeerForIP finds which peer handles the given IP address +func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State { + if statusRecorder == nil { + return nil + } + + fullStatus := statusRecorder.GetFullStatus() + var bestMatch *peer.State + var bestPrefixLen int + + for _, peerState := range fullStatus.Peers { + routes := peerState.GetRoutes() + for route := range routes { + prefix, err := netip.ParsePrefix(route) + if err != nil { + continue + } + + if prefix.Contains(ip) && prefix.Bits() > bestPrefixLen { + peerStateCopy := peerState + bestMatch = &peerStateCopy + bestPrefixLen = prefix.Bits() + } + } + } + + return bestMatch +} + +func (u *upstreamResolverBase) debugUpstreamTimeout(upstream 
netip.AddrPort) string { + if u.statusRecorder == nil { + return "" + } + + peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder) + if peerInfo == nil { + return "" + } + + return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo)) +} diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index ddbf84ae4..6b7dcc05e 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -50,7 +50,9 @@ func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns } func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - upstreamExchangeClient := &dns.Client{} + upstreamExchangeClient := &dns.Client{ + Timeout: ClientTimeout, + } return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) } @@ -72,10 +74,11 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri } upstreamExchangeClient := &dns.Client{ - Dialer: dialer, + Dialer: dialer, + Timeout: timeout, } - return upstreamExchangeClient.Exchange(r, upstream) + return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) } func (u *upstreamResolver) isLocalResolver(upstream string) bool { diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 317588a27..434e5880b 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -34,7 +34,10 @@ func newUpstreamResolver( } func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream) + client := &dns.Client{ + Timeout: ClientTimeout, + } + return ExchangeWithFallback(ctx, client, r, upstream) } func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { diff --git 
a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 96b8bbb0f..eadcdd117 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -47,7 +47,9 @@ func newUpstreamResolver( } func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - client := &dns.Client{} + client := &dns.Client{ + Timeout: ClientTimeout, + } upstreamHost, _, err := net.SplitHostPort(upstream) if err != nil { return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err) @@ -110,7 +112,8 @@ func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Dura }, } client := &dns.Client{ - Dialer: dialer, + Dialer: dialer, + Timeout: dialTimeout, } return client, nil } diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 51d870e2a..e1573e75e 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -124,29 +124,26 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) } func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { + mockClient := &mockUpstreamResolver{ + err: dns.ErrTime, + r: new(dns.Msg), + rtt: time.Millisecond, + } + resolver := &upstreamResolverBase{ - ctx: context.TODO(), - upstreamClient: &mockUpstreamResolver{ - err: nil, - r: new(dns.Msg), - rtt: time.Millisecond, - }, + ctx: context.TODO(), + upstreamClient: mockClient, upstreamTimeout: UpstreamTimeout, - reactivatePeriod: reactivatePeriod, - failsTillDeact: failsTillDeact, + reactivatePeriod: time.Microsecond * 100, } addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())} - resolver.failsTillDeact = 0 - resolver.reactivatePeriod = time.Microsecond * 100 - - responseWriter := 
&test.MockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { return nil }, - } failed := false resolver.deactivate = func(error) { failed = true + // After deactivation, make the mock client work again + mockClient.err = nil } reactivated := false @@ -154,7 +151,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { reactivated = true } - resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA)) + resolver.ProbeAvailability() if !failed { t.Errorf("expected that resolving was deactivated") @@ -173,11 +170,6 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { return } - if resolver.failsCount.Load() != 0 { - t.Errorf("fails count after reactivation should be 0") - return - } - if resolver.disabled { t.Errorf("should be enabled") } diff --git a/client/internal/engine.go b/client/internal/engine.go index c1c67207b..c610de41e 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -7,6 +7,7 @@ import ( "math/rand" "net" "net/netip" + "net/url" "os" "reflect" "runtime" @@ -33,6 +34,7 @@ import ( nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" + dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/netflow" @@ -345,7 +347,7 @@ func (e *Engine) Stop() error { // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Connections to remote peers are not established here. 
// However, they will be established once an event with a list of peers to connect to will be received from Management Service -func (e *Engine) Start() error { +func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -401,6 +403,11 @@ func (e *Engine) Start() error { } e.dnsServer = dnsServer + // Populate DNS cache with NetbirdConfig and management URL for early resolution + if err := e.PopulateNetbirdConfig(netbirdConfig, mgmtURL); err != nil { + log.Warnf("failed to populate DNS cache: %v", err) + } + e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ Context: e.ctx, PublicKey: e.config.WgPrivateKey.PublicKey().String(), @@ -661,6 +668,30 @@ func (e *Engine) removePeer(peerKey string) error { return nil } +// PopulateNetbirdConfig populates the DNS cache with infrastructure domains from login response +func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error { + if e.dnsServer == nil { + return nil + } + + // Populate management URL if provided + if mgmtURL != nil { + if err := e.dnsServer.PopulateManagementDomain(mgmtURL); err != nil { + log.Warnf("failed to populate DNS cache with management URL: %v", err) + } + } + + // Populate NetbirdConfig domains if provided + if netbirdConfig != nil { + serverDomains := dnsconfig.ExtractFromNetbirdConfig(netbirdConfig) + if err := e.dnsServer.UpdateServerConfig(serverDomains); err != nil { + return fmt.Errorf("update DNS server config from NetbirdConfig: %w", err) + } + } + + return nil +} + func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -692,6 +723,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return fmt.Errorf("handle the flow configuration: %w", err) } + if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil { + log.Warnf("Failed to update DNS server config: %v", err) + } + 
// todo update signal } @@ -1557,7 +1592,14 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) { return dnsServer, nil default: - dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS) + + dnsServer, err := dns.NewDefaultServer(e.ctx, dns.DefaultServerConfig{ + WgInterface: e.wgInterface, + CustomAddress: e.config.CustomDNSAddress, + StatusRecorder: e.statusRecorder, + StateManager: e.stateManager, + DisableSys: e.config.DisableDNS, + }) if err != nil { return nil, err } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 9d6361bc0..fc58dbdba 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -266,7 +266,7 @@ func TestEngine_SSH(t *testing.T) { }, }, nil } - err = engine.Start() + err = engine.Start(nil, nil) if err != nil { t.Fatal(err) } @@ -612,7 +612,7 @@ func TestEngine_Sync(t *testing.T) { } }() - err = engine.Start() + err = engine.Start(nil, nil) if err != nil { t.Fatal(err) return @@ -1069,7 +1069,7 @@ func TestEngine_MultiplePeers(t *testing.T) { defer mu.Unlock() guid := fmt.Sprintf("{%s}", uuid.New().String()) device.CustomWindowsGUIDString = strings.ToLower(guid) - err = engine.Start() + err = engine.Start(nil, nil) if err != nil { t.Errorf("unable to start engine for peer %d with error %v", j, err) wg.Done() diff --git a/client/internal/login.go b/client/internal/login.go index d5412a110..257e3c3ac 100644 --- a/client/internal/login.go +++ b/client/internal/login.go @@ -40,7 +40,7 @@ func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, return false, err } - _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) + _, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) if isLoginNeeded(err) { return true, nil } @@ -69,14 +69,18 @@ func Login(ctx context.Context, config *profilemanager.Config, setupKey string, return err } - serverKey, err := 
doMgmLogin(ctx, mgmClient, pubSSHKey, config) + serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config) if serverKey != nil && isRegistrationNeeded(err) { log.Debugf("peer registration required") _, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config) + if err != nil { + return err + } + } else if err != nil { return err } - return err + return nil } func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) { @@ -101,11 +105,11 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm return mgmClient, err } -func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, error) { +func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) { serverKey, err := mgmClient.GetServerPublicKey() if err != nil { log.Errorf("failed while getting Management Service public key: %v", err) - return nil, err + return nil, nil, err } sysInfo := system.GetInfo(ctx) @@ -121,8 +125,8 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte config.BlockInbound, config.LazyConnectionEnabled, ) - _, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) - return serverKey, err + loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) + return serverKey, loginResp, err } // registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. 
diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index ba27df654..9069cdcc5 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -2,11 +2,13 @@ package dnsinterceptor import ( "context" + "errors" "fmt" "net/netip" "runtime" "strings" "sync" + "time" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" @@ -26,6 +28,8 @@ import ( "github.com/netbirdio/netbird/route" ) +const dnsTimeout = 8 * time.Second + type domainMap map[domain.Domain][]netip.Prefix type internalDNATer interface { @@ -243,7 +247,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout) + client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout) if err != nil { d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err)) return @@ -254,9 +258,20 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) - reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) + ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) + defer cancel() + + startTime := time.Now() + reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream) if err != nil { - logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) + if errors.Is(err, context.DeadlineExceeded) { + elapsed := time.Since(startTime) + peerInfo := d.debugPeerTimeout(upstreamIP, peerKey) + logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v", + elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, 
err) + } else { + logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) + } if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { logger.Errorf("failed writing DNS response: %v", err) } @@ -568,3 +583,16 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR } return } + +func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string { + if d.statusRecorder == nil { + return "" + } + + peerState, err := d.statusRecorder.GetPeer(peerKey) + if err != nil { + return fmt.Sprintf(" (peer %s state error: %v)", peerKey[:8], err) + } + + return fmt.Sprintf(" (peer %s)", nbdns.FormatPeerStatus(&peerState)) +} From e14c6de20323cef9dacc2b3d6aac72a9e19c3df3 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 29 Aug 2025 17:41:20 +0200 Subject: [PATCH 08/41] [management] fix ephemeral flag on peer batch response (#4420) --- .../http/handlers/peers/peers_handler.go | 58 +++++++++---------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 414c7b1b9..af501e151 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -14,11 +14,11 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" - nbpeer "github.com/netbirdio/netbird/management/server/peer" 
"github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // Handler is a handler that returns peers of the account @@ -354,7 +354,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD } return &api.Peer{ - CreatedAt: peer.CreatedAt, + CreatedAt: peer.CreatedAt, Id: peer.ID, Name: peer.Name, Ip: peer.IP.String(), @@ -391,33 +391,33 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn } return &api.PeerBatch{ - CreatedAt: peer.CreatedAt, - Id: peer.ID, - Name: peer.Name, - Ip: peer.IP.String(), - ConnectionIp: peer.Location.ConnectionIP.String(), - Connected: peer.Status.Connected, - LastSeen: peer.Status.LastSeen, - Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), - KernelVersion: peer.Meta.KernelVersion, - GeonameId: int(peer.Location.GeoNameID), - Version: peer.Meta.WtVersion, - Groups: groupsInfo, - SshEnabled: peer.SSHEnabled, - Hostname: peer.Meta.Hostname, - UserId: peer.UserID, - UiVersion: peer.Meta.UIVersion, - DnsLabel: fqdn(peer, dnsDomain), - ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain), - LoginExpirationEnabled: peer.LoginExpirationEnabled, - LastLogin: peer.GetLastLogin(), - LoginExpired: peer.Status.LoginExpired, - AccessiblePeersCount: accessiblePeersCount, - CountryCode: peer.Location.CountryCode, - CityName: peer.Location.CityName, - SerialNumber: peer.Meta.SystemSerialNumber, - + CreatedAt: peer.CreatedAt, + Id: peer.ID, + Name: peer.Name, + Ip: peer.IP.String(), + ConnectionIp: peer.Location.ConnectionIP.String(), + Connected: peer.Status.Connected, + LastSeen: peer.Status.LastSeen, + Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), + KernelVersion: peer.Meta.KernelVersion, + GeonameId: int(peer.Location.GeoNameID), + Version: peer.Meta.WtVersion, + Groups: groupsInfo, + SshEnabled: peer.SSHEnabled, + Hostname: peer.Meta.Hostname, + UserId: peer.UserID, + UiVersion: peer.Meta.UIVersion, + DnsLabel: fqdn(peer, 
dnsDomain), + ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain), + LoginExpirationEnabled: peer.LoginExpirationEnabled, + LastLogin: peer.GetLastLogin(), + LoginExpired: peer.Status.LoginExpired, + AccessiblePeersCount: accessiblePeersCount, + CountryCode: peer.Location.CountryCode, + CityName: peer.Location.CityName, + SerialNumber: peer.Meta.SystemSerialNumber, InactivityExpirationEnabled: peer.InactivityExpirationEnabled, + Ephemeral: peer.Ephemeral, } } From 149559a06b1485f0b88ff881c446a40eed26feec Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Fri, 29 Aug 2025 19:48:40 +0200 Subject: [PATCH 09/41] [management] login filter to fix multiple peers connected with the same pub key (#3986) --- management/server/account.go | 11 +- management/server/account/manager.go | 1 + management/server/grpcserver.go | 106 +++++-- management/server/loginfilter.go | 160 ++++++++++ management/server/loginfilter_test.go | 275 ++++++++++++++++++ management/server/mock_server/account_mock.go | 13 +- management/server/peer_test.go | 1 - management/server/telemetry/grpc_metrics.go | 30 ++ shared/management/status/error.go | 5 +- 9 files changed, 568 insertions(+), 34 deletions(-) create mode 100644 management/server/loginfilter.go create mode 100644 management/server/loginfilter_test.go diff --git a/management/server/account.go b/management/server/account.go index f217eadb3..b57550f14 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -104,6 +104,8 @@ type DefaultAccountManager struct { accountUpdateLocks sync.Map updateAccountPeersBufferInterval atomic.Int64 + loginFilter *loginFilter + disableDefaultPolicy bool } @@ -211,6 +213,7 @@ func BuildManager( proxyController: proxyController, settingsManager: settingsManager, permissionsManager: permissionsManager, + loginFilter: newLoginFilter(), disableDefaultPolicy: disableDefaultPolicy, } @@ -1612,6 +1615,10 @@ func domainIsUpToDate(domain string, domainCategory 
string, userAuth nbcontext.U return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } +func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool { + return am.loginFilter.allowLogin(wgPubKey, metahash) +} + func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { start := time.Now() defer func() { @@ -1628,6 +1635,9 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } + metahash := metaHash(meta, realIP.String()) + am.loginFilter.addLogin(peerPubKey, metahash) + return peer, netMap, postureChecks, nil } @@ -1636,7 +1646,6 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account if err != nil { log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } - return nil } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index c7a39004a..c7329a1da 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -123,4 +123,5 @@ type Manager interface { UpdateToPrimaryAccount(ctx context.Context, accountId string) error GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + AllowSync(string, uint64) bool } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 41d463ae3..95d4f9eee 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -2,9 +2,11 @@ package server import ( "context" + "errors" "fmt" "net" "net/netip" + "os" "strings" "sync" "time" @@ -38,20 +40,28 @@ import ( internalStatus 
"github.com/netbirdio/netbird/shared/management/status" ) +const ( + envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS" + envBlockPeers = "NB_BLOCK_SAME_PEERS" +) + // GRPCServer an instance of a Management gRPC API server type GRPCServer struct { accountManager account.Manager settingsManager settings.Manager wgKey wgtypes.Key proto.UnimplementedManagementServiceServer - peersUpdateManager *PeersUpdateManager - config *nbconfig.Config - secretsManager SecretsManager - appMetrics telemetry.AppMetrics - ephemeralManager *EphemeralManager - peerLocks sync.Map - authManager auth.Manager - integratedPeerValidator integrated_validator.IntegratedValidator + peersUpdateManager *PeersUpdateManager + config *nbconfig.Config + secretsManager SecretsManager + appMetrics telemetry.AppMetrics + ephemeralManager *EphemeralManager + peerLocks sync.Map + authManager auth.Manager + + logBlockedPeers bool + blockPeersWithSameConfig bool + integratedPeerValidator integrated_validator.IntegratedValidator } // NewServer creates a new Management server @@ -82,18 +92,23 @@ func NewServer( } } + logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true" + blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true" + return &GRPCServer{ wgKey: key, // peerKey -> event channel - peersUpdateManager: peersUpdateManager, - accountManager: accountManager, - settingsManager: settingsManager, - config: config, - secretsManager: secretsManager, - authManager: authManager, - appMetrics: appMetrics, - ephemeralManager: ephemeralManager, - integratedPeerValidator: integratedPeerValidator, + peersUpdateManager: peersUpdateManager, + accountManager: accountManager, + settingsManager: settingsManager, + config: config, + secretsManager: secretsManager, + authManager: authManager, + appMetrics: appMetrics, + ephemeralManager: ephemeralManager, + logBlockedPeers: logBlockedPeers, + blockPeersWithSameConfig: blockPeersWithSameConfig, + integratedPeerValidator: 
integratedPeerValidator, }, nil } @@ -136,9 +151,6 @@ func getRealIP(ctx context.Context) net.IP { // notifies the connected peer of any updates (e.g. new peers under the same account) func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { reqStart := time.Now() - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountSyncRequest() - } ctx := srv.Context() @@ -147,6 +159,25 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi if err != nil { return err } + realIP := getRealIP(ctx) + sRealIP := realIP.String() + peerMeta := extractPeerMeta(ctx, syncReq.GetMeta()) + metahashed := metaHash(peerMeta, sRealIP) + if !s.accountManager.AllowSync(peerKey.String(), metahashed) { + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountSyncRequestBlocked() + } + if s.logBlockedPeers { + log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed) + } + if s.blockPeersWithSameConfig { + return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn) + } + } + + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountSyncRequest() + } // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) @@ -172,14 +203,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) - realIP := getRealIP(ctx) - log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String()) + log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP) if syncReq.GetMeta() == nil { log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. 
Peer %s, remote addr %s", peerKey.String(), realIP) } - peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP) + peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) return mapError(ctx, err) @@ -345,6 +375,9 @@ func mapError(ctx context.Context, err error) error { default: } } + if errors.Is(err, internalStatus.ErrPeerAlreadyLoggedIn) { + return status.Error(codes.PermissionDenied, internalStatus.ErrPeerAlreadyLoggedIn.Error()) + } log.WithContext(ctx).Errorf("got an unhandled error: %s", err) return status.Errorf(codes.Internal, "failed handling request") } @@ -436,12 +469,9 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa // In case of the successful registration login is also successful func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { reqStart := time.Now() - - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountLoginRequest() - } realIP := getRealIP(ctx) - log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String()) + sRealIP := realIP.String() + log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP) loginReq := &proto.LoginRequest{} peerKey, err := s.parseRequest(ctx, req, loginReq) @@ -449,6 +479,24 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p return nil, err } + peerMeta := extractPeerMeta(ctx, loginReq.GetMeta()) + metahashed := metaHash(peerMeta, sRealIP) + if !s.accountManager.AllowSync(peerKey.String(), metahashed) { + if s.logBlockedPeers { + log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed) + } + if s.appMetrics != nil { + 
s.appMetrics.GRPCMetrics().CountLoginRequestBlocked() + } + if s.blockPeersWithSameConfig { + return nil, internalStatus.ErrPeerAlreadyLoggedIn + } + } + + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequest() + } + //nolint ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) @@ -485,7 +533,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{ WireGuardPubKey: peerKey.String(), SSHKey: string(sshKey), - Meta: extractPeerMeta(ctx, loginReq.GetMeta()), + Meta: peerMeta, UserID: userID, SetupKey: loginReq.GetSetupKey(), ConnectionIP: realIP, diff --git a/management/server/loginfilter.go b/management/server/loginfilter.go new file mode 100644 index 000000000..8604af6e2 --- /dev/null +++ b/management/server/loginfilter.go @@ -0,0 +1,160 @@ +package server + +import ( + "hash/fnv" + "math" + "sync" + "time" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +const ( + reconnThreshold = 5 * time.Minute + baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit + reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban + metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer +) + +type lfConfig struct { + reconnThreshold time.Duration + baseBlockDuration time.Duration + reconnLimitForBan int + metaChangeLimit int +} + +func initCfg() *lfConfig { + return &lfConfig{ + reconnThreshold: reconnThreshold, + baseBlockDuration: baseBlockDuration, + reconnLimitForBan: reconnLimitForBan, + metaChangeLimit: metaChangeLimit, + } +} + +type loginFilter struct { + mu sync.RWMutex + cfg *lfConfig + logged map[string]*peerState +} + +type peerState struct { + currentHash uint64 + sessionCounter 
int + sessionStart time.Time + lastSeen time.Time + isBanned bool + banLevel int + banExpiresAt time.Time + metaChangeCounter int + metaChangeWindowStart time.Time +} + +func newLoginFilter() *loginFilter { + return newLoginFilterWithCfg(initCfg()) +} + +func newLoginFilterWithCfg(cfg *lfConfig) *loginFilter { + return &loginFilter{ + logged: make(map[string]*peerState), + cfg: cfg, + } +} + +func (l *loginFilter) allowLogin(wgPubKey string, metaHash uint64) bool { + l.mu.RLock() + defer func() { + l.mu.RUnlock() + }() + state, ok := l.logged[wgPubKey] + if !ok { + return true + } + if state.isBanned && time.Now().Before(state.banExpiresAt) { + return false + } + if metaHash != state.currentHash { + if time.Now().Before(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) && state.metaChangeCounter >= l.cfg.metaChangeLimit { + return false + } + } + return true +} + +func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) { + now := time.Now() + l.mu.Lock() + defer func() { + l.mu.Unlock() + }() + + state, ok := l.logged[wgPubKey] + + if !ok { + l.logged[wgPubKey] = &peerState{ + currentHash: metaHash, + sessionCounter: 1, + sessionStart: now, + lastSeen: now, + metaChangeWindowStart: now, + metaChangeCounter: 1, + } + return + } + + if state.isBanned && now.After(state.banExpiresAt) { + state.isBanned = false + } + + if state.banLevel > 0 && now.Sub(state.lastSeen) > (2*l.cfg.baseBlockDuration) { + state.banLevel = 0 + } + + if metaHash != state.currentHash { + if now.After(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) { + state.metaChangeWindowStart = now + state.metaChangeCounter = 1 + } else { + state.metaChangeCounter++ + } + state.currentHash = metaHash + state.sessionCounter = 1 + state.sessionStart = now + state.lastSeen = now + return + } + + state.sessionCounter++ + if state.sessionCounter > l.cfg.reconnLimitForBan && now.Sub(state.sessionStart) < l.cfg.reconnThreshold { + state.isBanned = true + state.banLevel++ + + backoffFactor := 
math.Pow(2, float64(state.banLevel-1)) + duration := time.Duration(float64(l.cfg.baseBlockDuration) * backoffFactor) + state.banExpiresAt = now.Add(duration) + + state.sessionCounter = 0 + state.sessionStart = now + } + state.lastSeen = now +} + +func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 { + h := fnv.New64a() + + h.Write([]byte(meta.WtVersion)) + h.Write([]byte(meta.OSVersion)) + h.Write([]byte(meta.KernelVersion)) + h.Write([]byte(meta.Hostname)) + h.Write([]byte(meta.SystemSerialNumber)) + h.Write([]byte(pubip)) + + macs := uint64(0) + for _, na := range meta.NetworkAddresses { + for _, r := range na.Mac { + macs += uint64(r) + } + } + + return h.Sum64() + macs +} diff --git a/management/server/loginfilter_test.go b/management/server/loginfilter_test.go new file mode 100644 index 000000000..65782dd9d --- /dev/null +++ b/management/server/loginfilter_test.go @@ -0,0 +1,275 @@ +package server + +import ( + "hash/fnv" + "math" + "math/rand" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/suite" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +func testAdvancedCfg() *lfConfig { + return &lfConfig{ + reconnThreshold: 50 * time.Millisecond, + baseBlockDuration: 100 * time.Millisecond, + reconnLimitForBan: 3, + metaChangeLimit: 2, + } +} + +type LoginFilterTestSuite struct { + suite.Suite + filter *loginFilter +} + +func (s *LoginFilterTestSuite) SetupTest() { + s.filter = newLoginFilterWithCfg(testAdvancedCfg()) +} + +func TestLoginFilterTestSuite(t *testing.T) { + suite.Run(t, new(LoginFilterTestSuite)) +} + +func (s *LoginFilterTestSuite) TestFirstLoginIsAlwaysAllowed() { + pubKey := "PUB_KEY_A" + meta := uint64(1) + + s.True(s.filter.allowLogin(pubKey, meta)) + + s.filter.addLogin(pubKey, meta) + s.Require().Contains(s.filter.logged, pubKey) + s.Equal(1, s.filter.logged[pubKey].sessionCounter) +} + +func (s *LoginFilterTestSuite) TestFlappingSameHashTriggersBan() { + pubKey := "PUB_KEY_A" + 
meta := uint64(1) + limit := s.filter.cfg.reconnLimitForBan + + for i := 0; i <= limit; i++ { + s.filter.addLogin(pubKey, meta) + } + + s.False(s.filter.allowLogin(pubKey, meta)) + s.Require().Contains(s.filter.logged, pubKey) + s.True(s.filter.logged[pubKey].isBanned) +} + +func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() { + pubKey := "PUB_KEY_A" + meta := uint64(1) + limit := s.filter.cfg.reconnLimitForBan + baseBan := s.filter.cfg.baseBlockDuration + + for i := 0; i <= limit; i++ { + s.filter.addLogin(pubKey, meta) + } + s.Require().Contains(s.filter.logged, pubKey) + s.True(s.filter.logged[pubKey].isBanned) + s.Equal(1, s.filter.logged[pubKey].banLevel) + firstBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen) + s.InDelta(baseBan, firstBanDuration, float64(time.Millisecond)) + + s.filter.logged[pubKey].banExpiresAt = time.Now().Add(-time.Second) + s.filter.logged[pubKey].isBanned = false + + for i := 0; i <= limit; i++ { + s.filter.addLogin(pubKey, meta) + } + s.True(s.filter.logged[pubKey].isBanned) + s.Equal(2, s.filter.logged[pubKey].banLevel) + secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen) + expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1)) + s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond)) +} + +func (s *LoginFilterTestSuite) TestPeerIsAllowedAfterBanExpires() { + pubKey := "PUB_KEY_A" + meta := uint64(1) + + s.filter.logged[pubKey] = &peerState{ + isBanned: true, + banExpiresAt: time.Now().Add(-(s.filter.cfg.baseBlockDuration + time.Second)), + } + + s.True(s.filter.allowLogin(pubKey, meta)) + + s.filter.addLogin(pubKey, meta) + s.Require().Contains(s.filter.logged, pubKey) + s.False(s.filter.logged[pubKey].isBanned) +} + +func (s *LoginFilterTestSuite) TestBanLevelResetsAfterGoodBehavior() { + pubKey := "PUB_KEY_A" + meta := uint64(1) + + s.filter.logged[pubKey] = &peerState{ + 
currentHash: meta, + banLevel: 3, + lastSeen: time.Now().Add(-3 * s.filter.cfg.baseBlockDuration), + } + + s.filter.addLogin(pubKey, meta) + s.Require().Contains(s.filter.logged, pubKey) + s.Equal(0, s.filter.logged[pubKey].banLevel) +} + +func (s *LoginFilterTestSuite) TestFlappingDifferentHashesTriggersBlock() { + pubKey := "PUB_KEY_A" + limit := s.filter.cfg.metaChangeLimit + + for i := range limit { + s.filter.addLogin(pubKey, uint64(i+1)) + } + + s.Require().Contains(s.filter.logged, pubKey) + s.Equal(limit, s.filter.logged[pubKey].metaChangeCounter) + + isAllowed := s.filter.allowLogin(pubKey, uint64(limit+1)) + + s.False(isAllowed, "should block new meta hash after limit is reached") +} + +func (s *LoginFilterTestSuite) TestMetaChangeIsAllowedAfterWindowResets() { + pubKey := "PUB_KEY_A" + meta1 := uint64(1) + meta2 := uint64(2) + meta3 := uint64(3) + + s.filter.addLogin(pubKey, meta1) + s.filter.addLogin(pubKey, meta2) + s.Require().Contains(s.filter.logged, pubKey) + s.Equal(s.filter.cfg.metaChangeLimit, s.filter.logged[pubKey].metaChangeCounter) + s.False(s.filter.allowLogin(pubKey, meta3), "should be blocked inside window") + + s.filter.logged[pubKey].metaChangeWindowStart = time.Now().Add(-(s.filter.cfg.reconnThreshold + time.Second)) + + s.True(s.filter.allowLogin(pubKey, meta3), "should be allowed after window expires") + + s.filter.addLogin(pubKey, meta3) + s.Equal(1, s.filter.logged[pubKey].metaChangeCounter, "meta change counter should reset") +} + +func BenchmarkHashingMethods(b *testing.B) { + meta := nbpeer.PeerSystemMeta{ + WtVersion: "1.25.1", + OSVersion: "Ubuntu 22.04.3 LTS", + KernelVersion: "5.15.0-76-generic", + Hostname: "prod-server-database-01", + SystemSerialNumber: "PC-1234567890", + NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}}, + } + pubip := "8.8.8.8" + + var resultString string + var resultUint uint64 + + b.Run("BuilderString", func(b *testing.B) { + b.ReportAllocs() + 
b.ResetTimer() + for i := 0; i < b.N; i++ { + resultString = builderString(meta, pubip) + } + }) + + b.Run("FnvHashToString", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + resultString = fnvHashToString(meta, pubip) + } + }) + + b.Run("FnvHashToUint64 - used", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + resultUint = metaHash(meta, pubip) + } + }) + + _ = resultString + _ = resultUint +} + +func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string { + h := fnv.New64a() + + if len(meta.NetworkAddresses) != 0 { + for _, na := range meta.NetworkAddresses { + h.Write([]byte(na.Mac)) + } + } + + h.Write([]byte(meta.WtVersion)) + h.Write([]byte(meta.OSVersion)) + h.Write([]byte(meta.KernelVersion)) + h.Write([]byte(meta.Hostname)) + h.Write([]byte(meta.SystemSerialNumber)) + h.Write([]byte(pubip)) + + return strconv.FormatUint(h.Sum64(), 16) +} + +func builderString(meta nbpeer.PeerSystemMeta, pubip string) string { + mac := getMacAddress(meta.NetworkAddresses) + estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) + + len(pubip) + len(mac) + 6 + + var b strings.Builder + b.Grow(estimatedSize) + + b.WriteString(meta.WtVersion) + b.WriteByte('|') + b.WriteString(meta.OSVersion) + b.WriteByte('|') + b.WriteString(meta.KernelVersion) + b.WriteByte('|') + b.WriteString(meta.Hostname) + b.WriteByte('|') + b.WriteString(meta.SystemSerialNumber) + b.WriteByte('|') + b.WriteString(pubip) + + return b.String() +} + +func getMacAddress(nas []nbpeer.NetworkAddress) string { + if len(nas) == 0 { + return "" + } + macs := make([]string, 0, len(nas)) + for _, na := range nas { + macs = append(macs, na.Mac) + } + return strings.Join(macs, "/") +} + +func BenchmarkLoginFilter_ParallelLoad(b *testing.B) { + filter := newLoginFilterWithCfg(testAdvancedCfg()) + numKeys := 100000 + pubKeys := make([]string, 
numKeys) + for i := range numKeys { + pubKeys[i] = "PUB_KEY_" + strconv.Itoa(i) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + + for pb.Next() { + key := pubKeys[r.Intn(numKeys)] + meta := r.Uint64() + + if filter.allowLogin(key, meta) { + filter.addLogin(key, meta) + } + } + }) +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 6f9c2696f..caba58c8b 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -121,8 +121,10 @@ type MockAccountManager struct { GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) - UpdateAccountPeersFunc func(ctx context.Context, accountID string) - BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) + + AllowSyncFunc func(string, uint64) bool + UpdateAccountPeersFunc func(ctx context.Context, accountID string) + BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) } func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { @@ -953,3 +955,10 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth n } return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") } + +func (am *MockAccountManager) AllowSync(key string, hash uint64) bool { + if am.AllowSyncFunc != nil { + return am.AllowSyncFunc(key, hash) + } + return true +} diff --git a/management/server/peer_test.go b/management/server/peer_test.go index c4822aa62..df40014f0 100644 --- 
a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1609,7 +1609,6 @@ func Test_LoginPeer(t *testing.T) { testCases := []struct { name string setupKey string - wireGuardPubKey string expectExtraDNSLabelsMismatch bool extraDNSLabels []string expectLoginError bool diff --git a/management/server/telemetry/grpc_metrics.go b/management/server/telemetry/grpc_metrics.go index 9840b4b32..d4301802f 100644 --- a/management/server/telemetry/grpc_metrics.go +++ b/management/server/telemetry/grpc_metrics.go @@ -15,8 +15,10 @@ const HighLatencyThreshold = time.Second * 7 type GRPCMetrics struct { meter metric.Meter syncRequestsCounter metric.Int64Counter + syncRequestsBlockedCounter metric.Int64Counter syncRequestHighLatencyCounter metric.Int64Counter loginRequestsCounter metric.Int64Counter + loginRequestsBlockedCounter metric.Int64Counter loginRequestHighLatencyCounter metric.Int64Counter getKeyRequestsCounter metric.Int64Counter activeStreamsGauge metric.Int64ObservableGauge @@ -36,6 +38,14 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } + syncRequestsBlockedCounter, err := meter.Int64Counter("management.grpc.sync.request.blocked.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of sync gRPC requests from blocked peers"), + ) + if err != nil { + return nil, err + } + syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter", metric.WithUnit("1"), metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"), @@ -52,6 +62,14 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return nil, err } + loginRequestsBlockedCounter, err := meter.Int64Counter("management.grpc.login.request.blocked.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of login gRPC requests from 
blocked peers"), + ) + if err != nil { + return nil, err + } + loginRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.login.request.high.latency.counter", metric.WithUnit("1"), metric.WithDescription("Number of login gRPC requests from the peers that took longer than the threshold to authenticate and receive initial configuration and relay credentials"), @@ -107,8 +125,10 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro return &GRPCMetrics{ meter: meter, syncRequestsCounter: syncRequestsCounter, + syncRequestsBlockedCounter: syncRequestsBlockedCounter, syncRequestHighLatencyCounter: syncRequestHighLatencyCounter, loginRequestsCounter: loginRequestsCounter, + loginRequestsBlockedCounter: loginRequestsBlockedCounter, loginRequestHighLatencyCounter: loginRequestHighLatencyCounter, getKeyRequestsCounter: getKeyRequestsCounter, activeStreamsGauge: activeStreamsGauge, @@ -124,6 +144,11 @@ func (grpcMetrics *GRPCMetrics) CountSyncRequest() { grpcMetrics.syncRequestsCounter.Add(grpcMetrics.ctx, 1) } +// CountSyncRequestBlocked counts the number of gRPC sync requests from blocked peers +func (grpcMetrics *GRPCMetrics) CountSyncRequestBlocked() { + grpcMetrics.syncRequestsBlockedCounter.Add(grpcMetrics.ctx, 1) +} + // CountGetKeyRequest counts the number of gRPC get server key requests coming to the gRPC API func (grpcMetrics *GRPCMetrics) CountGetKeyRequest() { grpcMetrics.getKeyRequestsCounter.Add(grpcMetrics.ctx, 1) @@ -134,6 +159,11 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequest() { grpcMetrics.loginRequestsCounter.Add(grpcMetrics.ctx, 1) } +// CountLoginRequestBlocked counts the number of gRPC login requests from blocked peers +func (grpcMetrics *GRPCMetrics) CountLoginRequestBlocked() { + grpcMetrics.loginRequestsBlockedCounter.Add(grpcMetrics.ctx, 1) +} + // CountLoginRequestDuration counts the duration of the login gRPC requests func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration 
time.Duration, accountID string) { grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds()) diff --git a/shared/management/status/error.go b/shared/management/status/error.go index 7660174d6..52d27b062 100644 --- a/shared/management/status/error.go +++ b/shared/management/status/error.go @@ -42,7 +42,10 @@ const ( // Type is a type of the Error type Type int32 -var ErrExtraSettingsNotFound = fmt.Errorf("extra settings not found") +var ( + ErrExtraSettingsNotFound = errors.New("extra settings not found") + ErrPeerAlreadyLoggedIn = errors.New("peer with the same public key is already logged in") +) // Error is an internal error type Error struct { From 6fc50a438f507989aedfa353f8d87906d61335f4 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Sat, 30 Aug 2025 12:46:54 +0200 Subject: [PATCH 10/41] [management] remove withContext from store methods (#4422) --- management/server/store/sql_store.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 6ef93f0d1..45561f950 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -914,7 +914,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) { var account types.Account - result := s.db.WithContext(ctx).Select("id").Order("created_at desc").Limit(1).Find(&account) + result := s.db.Select("id").Order("created_at desc").Limit(1).Find(&account) if result.Error != nil { return "", status.NewGetAccountFromStoreError(result.Error) } @@ -1399,7 +1399,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI PeerID: peerID, } - err := s.db.WithContext(ctx).Clauses(clause.OnConflict{ + err := s.db.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "group_id"}, {Name: 
"peer_id"}}, DoNothing: true, }).Create(peer).Error @@ -1414,7 +1414,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI // RemovePeerFromGroup removes a peer from a group func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error { - err := s.db.WithContext(ctx). + err := s.db. Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error if err != nil { @@ -1427,7 +1427,7 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group // RemovePeerFromAllGroups removes a peer from all groups func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error { - err := s.db.WithContext(ctx). + err := s.db. Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error if err != nil { @@ -2015,7 +2015,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error { } func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return s.db.Transaction(func(tx *gorm.DB) error { if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil { return fmt.Errorf("delete policy rules: %w", err) } @@ -2706,7 +2706,7 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength } func (s *SqlStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) { - tx := s.db.WithContext(ctx) + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } From 4d3dc3475dcfb7694a73f03393f2335ecfda65ed Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Sat, 30 Aug 2025 12:47:13 +0200 Subject: [PATCH 11/41] [management] remove duplicated removal of groups on peer delete (#4421) --- management/server/grpcserver.go | 2 -- 
management/server/peer.go | 4 ---- 2 files changed, 6 deletions(-) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 95d4f9eee..27d54e6c2 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -999,8 +999,6 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (* return nil, mapError(ctx, err) } - s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID) - log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start)) return &proto.Empty{}, nil diff --git a/management/server/peer.go b/management/server/peer.go index 8af71cbd2..3c2ebe6b6 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -368,10 +368,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil { - return fmt.Errorf("failed to remove peer from groups: %w", err) - } - eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) if err != nil { return fmt.Errorf("failed to delete peer: %w", err) From d817584f52c71d0fcba7caf295ae8c8da2393da0 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sun, 31 Aug 2025 12:19:56 -0300 Subject: [PATCH 12/41] [misc] fix Windows client and management bench tests (#4424) Windows tests had too many directories, causing issues to the payload via psexec. Also migrated all checked benchmarks to send data to grafana. 
--- .github/workflows/golang-test-linux.yml | 35 +++- .github/workflows/golang-test-windows.yml | 2 +- management/README.md | 3 + management/server/account_test.go | 60 +++--- .../peers_handler_benchmark_test.go | 19 +- .../setupkeys_handler_benchmark_test.go | 23 +-- .../users_handler_benchmark_test.go | 23 +-- .../setupkeys_handler_integration_test.go | 13 +- .../testing/testing_tools/channel/channel.go | 137 ++++++++++++++ .../http/testing/testing_tools/tools.go | 173 +++--------------- management/server/peer_test.go | 12 +- shared/management/client/rest/client_test.go | 4 +- 12 files changed, 282 insertions(+), 222 deletions(-) create mode 100644 management/server/http/testing/testing_tools/channel/channel.go diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 0013833c4..f7b4e238f 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -382,6 +382,32 @@ jobs: store: [ 'sqlite', 'postgres' ] runs-on: ubuntu-22.04 steps: + - name: Create Docker network + run: docker network create promnet + + - name: Start Prometheus Pushgateway + run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway + + - name: Start Prometheus (for Pushgateway forwarding) + run: | + echo ' + global: + scrape_interval: 15s + scrape_configs: + - job_name: "pushgateway" + static_configs: + - targets: ["pushgateway:9091"] + remote_write: + - url: ${{ secrets.GRAFANA_URL }} + basic_auth: + username: ${{ secrets.GRAFANA_USER }} + password: ${{ secrets.GRAFANA_API_KEY }} + ' > prometheus.yml + + docker run -d --name prometheus --network promnet \ + -v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \ + -p 9090:9090 \ + prom/prometheus - name: Install Go uses: actions/setup-go@v5 with: @@ -428,9 +454,10 @@ jobs: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \ CI=true \ + GIT_BRANCH=${{ github.ref_name }} \ go test -tags devcert -run=^$ 
-bench=. \ - -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ - -timeout 20m ./management/... ./shared/management/... + -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \ + -timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http) api_benchmark: name: "Management / Benchmark (API)" @@ -521,7 +548,7 @@ jobs: -run=^$ \ -bench=. \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \ - -timeout 20m ./management/... ./shared/management/... + -timeout 20m ./management/server/http/... api_integration_test: name: "Management / Integration" @@ -571,4 +598,4 @@ jobs: CI=true \ go test -tags=integration \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ - -timeout 20m ./management/... ./shared/management/... \ No newline at end of file + -timeout 20m ./management/server/http/... diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index d9ff0a84b..2083c0721 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -63,7 +63,7 @@ jobs: - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy - - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV + - run: echo "files=$(go list ./... 
| ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV - name: test run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" diff --git a/management/README.md b/management/README.md index 1122a9e76..c70285d43 100644 --- a/management/README.md +++ b/management/README.md @@ -111,3 +111,6 @@ Generate gRpc code: #!/bin/bash protoc -I proto/ proto/management.proto --go_out=. --go-grpc_out=. ``` + + + diff --git a/management/server/account_test.go b/management/server/account_test.go index 252be23f7..66cf93286 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -15,6 +15,7 @@ import ( "time" "github.com/golang/mock/gomock" + "github.com/prometheus/client_golang/prometheus/push" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -25,6 +26,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -3046,19 +3048,14 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") - minExpected := bc.minMsPerOpLocal maxExpected := bc.maxMsPerOpLocal if os.Getenv("CI") == "true" { - minExpected = bc.minMsPerOpCICD maxExpected = bc.maxMsPerOpCICD + testing_tools.EvaluateBenchmarkResults(b, bc.name, 
time.Since(start), "sync", "syncAndMark") } - if msPerOp < minExpected { - b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) - } - - if msPerOp > (maxExpected * 1.1) { - b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) + if msPerOp > maxExpected { + b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } @@ -3121,19 +3118,14 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") - minExpected := bc.minMsPerOpLocal maxExpected := bc.maxMsPerOpLocal if os.Getenv("CI") == "true" { - minExpected = bc.minMsPerOpCICD maxExpected = bc.maxMsPerOpCICD + testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "existingPeer") } - if msPerOp < minExpected { - b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) - } - - if msPerOp > (maxExpected * 1.1) { - b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) + if msPerOp > maxExpected { + b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } @@ -3196,24 +3188,44 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") - minExpected := bc.minMsPerOpLocal maxExpected := bc.maxMsPerOpLocal if os.Getenv("CI") == "true" { - minExpected = bc.minMsPerOpCICD maxExpected = bc.maxMsPerOpCICD + testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer") } - if msPerOp < minExpected { - b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) - } - - if msPerOp > (maxExpected * 1.1) { - b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f 
ms/op)", bc.name, msPerOp, maxExpected) + if msPerOp > maxExpected { + b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } } +func TestMain(m *testing.M) { + exitCode := m.Run() + + if exitCode == 0 && os.Getenv("CI") == "true" { + runID := os.Getenv("GITHUB_RUN_ID") + storeEngine := os.Getenv("NETBIRD_STORE_ENGINE") + err := push.New("http://localhost:9091", "account_manager_benchmark"). + Collector(testing_tools.BenchmarkDuration). + Grouping("ci_run", runID). + Grouping("store_engine", storeEngine). + Push() + if err != nil { + log.Printf("Failed to push metrics: %v", err) + } else { + time.Sleep(1 * time.Minute) + _ = push.New("http://localhost:9091", "account_manager_benchmark"). + Grouping("ci_run", runID). + Grouping("store_engine", storeEngine). + Delete() + } + } + + os.Exit(exitCode) +} + func Test_GetCreateAccountByPrivateDomain(t *testing.T) { manager, err := createManager(t) if err != nil { diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go index 52737e4eb..3fe3fe809 100644 --- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -17,8 +17,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" ) const modulePeers = "peers" @@ -47,7 +48,7 @@ func BenchmarkUpdatePeer(b *testing.B) { for name, bc := range benchCasesPeers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + apiHandler, 
am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -65,7 +66,7 @@ func BenchmarkUpdatePeer(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate) }) } } @@ -82,7 +83,7 @@ func BenchmarkGetOnePeer(b *testing.B) { for name, bc := range benchCasesPeers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -92,7 +93,7 @@ func BenchmarkGetOnePeer(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne) }) } } @@ -109,7 +110,7 @@ func BenchmarkGetAllPeers(b *testing.B) { for name, bc := range benchCasesPeers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -119,7 +120,7 @@ func BenchmarkGetAllPeers(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, 
name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll) }) } } @@ -136,7 +137,7 @@ func BenchmarkDeletePeer(b *testing.B) { for name, bc := range benchCasesPeers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), 1000, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -146,7 +147,7 @@ func BenchmarkDeletePeer(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete) }) } } diff --git a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go index 9404c4ee4..36b226db0 100644 --- a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go @@ -17,8 +17,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" ) // Map to store peers, groups, users, and setupKeys by name @@ -47,7 +48,7 @@ func BenchmarkCreateSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) 
{ - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -69,7 +70,7 @@ func BenchmarkCreateSetupKey(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate) }) } } @@ -86,7 +87,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -109,7 +110,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate) }) } } @@ -126,7 +127,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, 
bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -136,7 +137,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne) }) } } @@ -153,7 +154,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -163,7 +164,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll) }) } } @@ -180,7 +181,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000) b.ResetTimer() @@ -190,7 +191,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete) + 
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete) }) } } diff --git a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go index 844b3e7a6..2868a20bd 100644 --- a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go @@ -18,8 +18,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" ) const moduleUsers = "users" @@ -46,7 +47,7 @@ func BenchmarkUpdateUser(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) recorder := httptest.NewRecorder() @@ -71,7 +72,7 @@ func BenchmarkUpdateUser(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate) }) } } @@ -84,18 +85,18 @@ func BenchmarkGetOneUser(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + 
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) recorder := httptest.NewRecorder() b.ResetTimer() start := time.Now() + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId) for i := 0; i < b.N; i++ { - req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId) apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne) }) } } @@ -110,18 +111,18 @@ func BenchmarkGetAllUsers(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) recorder := httptest.NewRecorder() b.ResetTimer() start := time.Now() + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId) for i := 0; i < b.N; i++ { - req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId) apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll) }) } } @@ -136,7 +137,7 @@ func BenchmarkDeleteUsers(b *testing.B) { for name, bc := range 
benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) + apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys) recorder := httptest.NewRecorder() @@ -147,7 +148,7 @@ func BenchmarkDeleteUsers(b *testing.B) { apiHandler.ServeHTTP(recorder, req) } - testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete) + testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete) }) } } diff --git a/management/server/http/testing/integration/setupkeys_handler_integration_test.go b/management/server/http/testing/integration/setupkeys_handler_integration_test.go index 9f04e3c24..1079de4aa 100644 --- a/management/server/http/testing/integration/setupkeys_handler_integration_test.go +++ b/management/server/http/testing/integration/setupkeys_handler_integration_test.go @@ -15,9 +15,10 @@ import ( "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" + "github.com/netbirdio/netbird/shared/management/http/api" ) func Test_SetupKeys_Create(t *testing.T) { @@ -287,7 +288,7 @@ func Test_SetupKeys_Create(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(user.name+" - "+tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) body, 
err := json.Marshal(tc.requestBody) if err != nil { @@ -572,7 +573,7 @@ func Test_SetupKeys_Update(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) body, err := json.Marshal(tc.requestBody) if err != nil { @@ -751,7 +752,7 @@ func Test_SetupKeys_Get(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) @@ -903,7 +904,7 @@ func Test_SetupKeys_GetAll(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, tc.requestPath, user.userId) @@ -1087,7 +1088,7 @@ func Test_SetupKeys_Delete(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) + apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) diff --git 
a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go new file mode 100644 index 000000000..741f03f18 --- /dev/null +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -0,0 +1,137 @@ +package channel + +import ( + "context" + "errors" + "net/http" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/netbirdio/management-integrations/integrations" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/auth" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/groups" + http2 "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/users" +) + +func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) + if err != nil { + t.Fatalf("Failed to create test store: %v", err) + } + t.Cleanup(cleanup) + + metrics, err := 
telemetry.NewDefaultAppMetrics(context.Background()) + if err != nil { + t.Fatalf("Failed to create metrics: %v", err) + } + + peersUpdateManager := server.NewPeersUpdateManager(nil) + updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId) + done := make(chan struct{}) + if validateUpdate { + go func() { + if expectedPeerUpdate != nil { + peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate) + } else { + peerShouldNotReceiveUpdate(t, updMsg) + } + close(done) + }() + } + + geoMock := &geolocation.Mock{} + validatorMock := server.MockIntegratedValidator{} + proxyController := integrations.NewController(store) + userManager := users.NewManager(store) + permissionsManager := permissions.NewManager(store) + settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager) + am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + // @note this is required so that PAT's validate from store, but JWT's are mocked + authManager := auth.NewManager(store, "", "", "", "", []string{}, false) + authManagerMock := &auth.MockManager{ + ValidateAndParseTokenFunc: mockValidateAndParseToken, + EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, + MarkPATUsedFunc: authManager.MarkPATUsed, + GetPATInfoFunc: authManager.GetPATInfo, + } + + networksManagerMock := networks.NewManagerMock() + resourcesManagerMock := resources.NewManagerMock() + routersManagerMock := routers.NewManagerMock() + groupsManagerMock := groups.NewManagerMock() + peersManager := peers.NewManager(store, permissionsManager) + + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, 
routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager) + if err != nil { + t.Fatalf("Failed to create API handler: %v", err) + } + + return apiHandler, am, done +} + +func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) { + t.Helper() + select { + case msg := <-updateMessage: + t.Errorf("Unexpected message received: %+v", msg) + case <-time.After(500 * time.Millisecond): + return + } +} + +func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) { + t.Helper() + + select { + case msg := <-updateMessage: + if msg == nil { + t.Errorf("Received nil update message, expected valid message") + } + assert.Equal(t, expected, msg) + case <-time.After(500 * time.Millisecond): + t.Errorf("Timed out waiting for update message") + } +} + +func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { + userAuth := nbcontext.UserAuth{} + + switch token { + case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": + userAuth.UserId = token + userAuth.AccountId = "testAccountId" + userAuth.Domain = "test.com" + userAuth.DomainCategory = "private" + case "otherUserId": + userAuth.UserId = "otherUserId" + userAuth.AccountId = "otherAccountId" + userAuth.Domain = "other.com" + userAuth.DomainCategory = "private" + case "invalidToken": + return userAuth, nil, errors.New("invalid token") + } + + jwtToken := jwt.New(jwt.SigningMethodHS256) + return userAuth, jwtToken, nil +} diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go index 1b82b156e..b7a63b104 100644 --- a/management/server/http/testing/testing_tools/tools.go +++ b/management/server/http/testing/testing_tools/tools.go @@ -3,7 +3,6 @@ package testing_tools 
import ( "bytes" "context" - "errors" "fmt" "io" "net" @@ -14,32 +13,12 @@ import ( "testing" "time" - "github.com/golang-jwt/jwt/v5" "github.com/prometheus/client_golang/prometheus" - "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/management-integrations/integrations" - "github.com/netbirdio/netbird/management/server/peers" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/users" - - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" - "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/auth" - nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/management/server/groups" - nbhttp "github.com/netbirdio/netbird/management/server/http" - "github.com/netbirdio/netbird/management/server/networks" - "github.com/netbirdio/netbird/management/server/networks/resources" - "github.com/netbirdio/netbird/management/server/networks/routers" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" ) @@ -106,90 +85,6 @@ type PerformanceMetrics struct { MaxMsPerOpCICD float64 } -func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { - store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) - if err != nil { - t.Fatalf("Failed to create test store: %v", err) 
- } - t.Cleanup(cleanup) - - metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) - if err != nil { - t.Fatalf("Failed to create metrics: %v", err) - } - - peersUpdateManager := server.NewPeersUpdateManager(nil) - updMsg := peersUpdateManager.CreateChannel(context.Background(), TestPeerId) - done := make(chan struct{}) - if validateUpdate { - go func() { - if expectedPeerUpdate != nil { - peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate) - } else { - peerShouldNotReceiveUpdate(t, updMsg) - } - close(done) - }() - } - - geoMock := &geolocation.Mock{} - validatorMock := server.MockIntegratedValidator{} - proxyController := integrations.NewController(store) - userManager := users.NewManager(store) - permissionsManager := permissions.NewManager(store) - settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager) - am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) - if err != nil { - t.Fatalf("Failed to create manager: %v", err) - } - - // @note this is required so that PAT's validate from store, but JWT's are mocked - authManager := auth.NewManager(store, "", "", "", "", []string{}, false) - authManagerMock := &auth.MockManager{ - ValidateAndParseTokenFunc: mockValidateAndParseToken, - EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, - MarkPATUsedFunc: authManager.MarkPATUsed, - GetPATInfoFunc: authManager.GetPATInfo, - } - - networksManagerMock := networks.NewManagerMock() - resourcesManagerMock := resources.NewManagerMock() - routersManagerMock := routers.NewManagerMock() - groupsManagerMock := groups.NewManagerMock() - peersManager := peers.NewManager(store, permissionsManager) - - apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, 
resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager) - if err != nil { - t.Fatalf("Failed to create API handler: %v", err) - } - - return apiHandler, am, done -} - -func peerShouldNotReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage) { - t.Helper() - select { - case msg := <-updateMessage: - t.Errorf("Unexpected message received: %+v", msg) - case <-time.After(500 * time.Millisecond): - return - } -} - -func peerShouldReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) { - t.Helper() - - select { - case msg := <-updateMessage: - if msg == nil { - t.Errorf("Received nil update message, expected valid message") - } - assert.Equal(t, expected, msg) - case <-time.After(500 * time.Millisecond): - t.Errorf("Timed out waiting for update message") - } -} - func BuildRequest(t TB, requestBody []byte, requestType, requestPath, user string) *http.Request { t.Helper() @@ -222,11 +117,11 @@ func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedSta return content, expectedStatus == http.StatusOK } -func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, groups, users, setupKeys int) { +func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, setupKeys int) { b.Helper() ctx := context.Background() - account, err := am.GetAccount(ctx, TestAccountId) + acc, err := am.GetAccount(ctx, TestAccountId) if err != nil { b.Fatalf("Failed to get account: %v", err) } @@ -242,23 +137,23 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, UserID: TestUserId, } - account.Peers[peer.ID] = peer + acc.Peers[peer.ID] = peer } // Create users for i := 0; i < users; i++ { user := &types.User{ Id: fmt.Sprintf("olduser-%d", i), - AccountID: account.Id, + 
AccountID: acc.Id, Role: types.UserRoleUser, } - account.Users[user.Id] = user + acc.Users[user.Id] = user } for i := 0; i < setupKeys; i++ { key := &types.SetupKey{ Id: fmt.Sprintf("oldkey-%d", i), - AccountID: account.Id, + AccountID: acc.Id, AutoGroups: []string{"someGroupID"}, UpdatedAt: time.Now().UTC(), ExpiresAt: util.ToPtr(time.Now().Add(ExpiresIn * time.Second)), @@ -266,11 +161,11 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro Type: "reusable", UsageLimit: 0, } - account.SetupKeys[key.Id] = key + acc.SetupKeys[key.Id] = key } // Create groups and policies - account.Policies = make([]*types.Policy, 0, groups) + acc.Policies = make([]*types.Policy, 0, groups) for i := 0; i < groups; i++ { groupID := fmt.Sprintf("group-%d", i) group := &types.Group{ @@ -281,7 +176,7 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro peerIndex := i*(peers/groups) + j group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) } - account.Groups[groupID] = group + acc.Groups[groupID] = group // Create a policy for this group policy := &types.Policy{ @@ -301,10 +196,10 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro }, }, } - account.Policies = append(account.Policies, policy) + acc.Policies = append(acc.Policies, policy) } - account.PostureChecks = []*posture.Checks{ + acc.PostureChecks = []*posture.Checks{ { ID: "PostureChecksAll", Name: "All", @@ -316,52 +211,38 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro }, } - err = am.Store.SaveAccount(context.Background(), account) + store := am.GetStore() + + err = store.SaveAccount(context.Background(), acc) if err != nil { b.Fatalf("Failed to save account: %v", err) } } -func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) { +func EvaluateAPIBenchmarkResults(b *testing.B, testCase 
string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) { b.Helper() - branch := os.Getenv("GIT_BRANCH") - if branch == "" { - b.Fatalf("environment variable GIT_BRANCH is not set") - } - if recorder.Code != http.StatusOK { b.Fatalf("Benchmark %s failed: unexpected status code %d", testCase, recorder.Code) } + EvaluateBenchmarkResults(b, testCase, duration, module, operation) + +} + +func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, module string, operation string) { + b.Helper() + + branch := os.Getenv("GIT_BRANCH") + if branch == "" && os.Getenv("CI") == "true" { + b.Fatalf("environment variable GIT_BRANCH is not set") + } + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 gauge := BenchmarkDuration.WithLabelValues(module, operation, testCase, branch) gauge.Set(msPerOp) b.ReportMetric(msPerOp, "ms/op") - -} - -func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { - userAuth := nbcontext.UserAuth{} - - switch token { - case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": - userAuth.UserId = token - userAuth.AccountId = "testAccountId" - userAuth.Domain = "test.com" - userAuth.DomainCategory = "private" - case "otherUserId": - userAuth.UserId = "otherUserId" - userAuth.AccountId = "otherAccountId" - userAuth.Domain = "other.com" - userAuth.DomainCategory = "private" - case "invalidToken": - return userAuth, nil, errors.New("invalid token") - } - - jwtToken := jwt.New(jwt.SigningMethodHS256) - return userAuth, jwtToken, nil } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index df40014f0..c77bf5e25 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -26,6 +26,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/management/internals/server/config" + 
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions" @@ -989,19 +990,14 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") - minExpected := bc.minMsPerOpLocal maxExpected := bc.maxMsPerOpLocal if os.Getenv("CI") == "true" { - minExpected = bc.minMsPerOpCICD maxExpected = bc.maxMsPerOpCICD + testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer") } - if msPerOp < minExpected { - b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) - } - - if msPerOp > (maxExpected * 1.1) { - b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) + if msPerOp > maxExpected { + b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } diff --git a/shared/management/client/rest/client_test.go b/shared/management/client/rest/client_test.go index 56c859652..54a0290d0 100644 --- a/shared/management/client/rest/client_test.go +++ b/shared/management/client/rest/client_test.go @@ -8,8 +8,8 @@ import ( "net/http/httptest" "testing" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel" "github.com/netbirdio/netbird/shared/management/client/rest" - "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" ) func withMockClient(callback func(*rest.Client, *http.ServeMux)) { @@ -26,7 +26,7 @@ func ptr[T any, PT *T](x T) PT { func withBlackBoxServer(t *testing.T, callback func(*rest.Client)) { t.Helper() - handler, _, _ := testing_tools.BuildApiBlackBoxWithDBState(t, "../../../../management/server/testdata/store.sql", nil, false) + handler, _, _ := 
channel.BuildApiBlackBoxWithDBState(t, "../../../../management/server/testdata/store.sql", nil, false) server := httptest.NewServer(handler) defer server.Close() c := rest.New(server.URL, "nbp_apTmlmUXHSC4PKmHwtIZNaGr8eqcVI2gMURp") From 21368b38d9e8c5e6e55881a069e912bc552b9b47 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 1 Sep 2025 10:42:01 +0200 Subject: [PATCH 13/41] [client] Update Pion ICE to the latest version (#4388) - Update Pion version - Update protobuf version --- client/iface/bind/ice_bind.go | 2 +- client/iface/bind/udp_mux.go | 4 +- client/iface/bind/udp_mux_universal.go | 2 +- client/internal/engine.go | 4 +- client/internal/peer/conn.go | 2 +- client/internal/peer/guard/ice_monitor.go | 2 +- client/internal/peer/ice/StunTurn.go | 2 +- client/internal/peer/ice/agent.go | 2 +- client/internal/peer/ice/config.go | 2 +- client/internal/peer/signaler.go | 2 +- client/internal/peer/worker_ice.go | 52 +++++++++++++++++------ client/internal/relay/relay.go | 2 +- go.mod | 16 ++++--- go.sum | 27 ++++++++---- 14 files changed, 80 insertions(+), 41 deletions(-) diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index f23be406e..359d2129b 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -8,7 +8,7 @@ import ( "runtime" "sync" - "github.com/pion/stun/v2" + "github.com/pion/stun/v3" "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.org/x/net/ipv4" diff --git a/client/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go index 29e5d7937..db7494405 100644 --- a/client/iface/bind/udp_mux.go +++ b/client/iface/bind/udp_mux.go @@ -8,9 +8,9 @@ import ( "strings" "sync" - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" "github.com/pion/logging" - "github.com/pion/stun/v2" + "github.com/pion/stun/v3" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" diff --git a/client/iface/bind/udp_mux_universal.go 
b/client/iface/bind/udp_mux_universal.go index b06da6712..a1f517dcd 100644 --- a/client/iface/bind/udp_mux_universal.go +++ b/client/iface/bind/udp_mux_universal.go @@ -15,7 +15,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/pion/logging" - "github.com/pion/stun/v2" + "github.com/pion/stun/v3" "github.com/pion/transport/v3" "github.com/netbirdio/netbird/client/iface/bufsize" diff --git a/client/internal/engine.go b/client/internal/engine.go index c610de41e..ca01bfd14 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -18,8 +18,8 @@ import ( "time" "github.com/hashicorp/go-multierror" - "github.com/pion/ice/v3" - "github.com/pion/stun/v2" + "github.com/pion/ice/v4" + "github.com/pion/stun/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index a6cf3cd25..224a8144c 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -11,7 +11,7 @@ import ( "sync" "time" - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go index b9c9aa134..70850e6eb 100644 --- a/client/internal/peer/guard/ice_monitor.go +++ b/client/internal/peer/guard/ice_monitor.go @@ -6,7 +6,7 @@ import ( "sync" "time" - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" log "github.com/sirupsen/logrus" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" diff --git a/client/internal/peer/ice/StunTurn.go b/client/internal/peer/ice/StunTurn.go index 63ee8c713..a389f5444 100644 --- a/client/internal/peer/ice/StunTurn.go +++ b/client/internal/peer/ice/StunTurn.go @@ -3,7 +3,7 @@ package ice import ( "sync/atomic" - "github.com/pion/stun/v2" + "github.com/pion/stun/v3" ) type StunTurn atomic.Value diff --git 
a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index 58c1bf634..e80c98884 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -4,7 +4,7 @@ import ( "sync" "time" - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" "github.com/pion/logging" "github.com/pion/randutil" log "github.com/sirupsen/logrus" diff --git a/client/internal/peer/ice/config.go b/client/internal/peer/ice/config.go index dd854a605..dd5d67403 100644 --- a/client/internal/peer/ice/config.go +++ b/client/internal/peer/ice/config.go @@ -1,7 +1,7 @@ package ice import ( - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" ) type Config struct { diff --git a/client/internal/peer/signaler.go b/client/internal/peer/signaler.go index ca1d421a5..b28906625 100644 --- a/client/internal/peer/signaler.go +++ b/client/internal/peer/signaler.go @@ -1,7 +1,7 @@ package peer import ( - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 4f00af829..e80641770 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -8,7 +8,7 @@ import ( "sync" "time" - "github.com/pion/ice/v3" + "github.com/pion/ice/v4" "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" @@ -214,10 +214,6 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates [] return nil, err } - if err := agent.OnSuccessfulSelectedPairBindingResponse(w.onSuccessfulSelectedPairBindingResponse); err != nil { - return nil, fmt.Errorf("failed setting binding response callback: %w", err) - } - return agent, nil } @@ -253,6 +249,11 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent w.closeAgent(agent, w.agentDialerCancel) return } + if pair == nil { + w.log.Warnf("selected candidate pair is nil, cannot proceed") + w.closeAgent(agent, 
w.agentDialerCancel) + return + } if !isRelayCandidate(pair.Local) { // dynamically set remote WireGuard port if other side specified a different one from the default one @@ -379,6 +380,27 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) { func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) { w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(), w.config.Key) + + w.muxAgent.Lock() + + pair, err := w.agent.GetSelectedCandidatePair() + if err != nil { + w.log.Warnf("failed to get selected candidate pair: %s", err) + w.muxAgent.Unlock() + return + } + if pair == nil { + w.log.Warnf("selected candidate pair is nil, cannot proceed") + w.muxAgent.Unlock() + return + } + w.muxAgent.Unlock() + + duration := time.Duration(pair.CurrentRoundTripTime() * float64(time.Second)) + if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil { + w.log.Debugf("failed to update latency for peer: %s", err) + return + } } func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) { @@ -400,13 +422,6 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia } } -func (w *WorkerICE) onSuccessfulSelectedPairBindingResponse(pair *ice.CandidatePair) { - if err := w.statusRecorder.UpdateLatency(w.config.Key, pair.Latency()); err != nil { - w.log.Debugf("failed to update latency for peer: %s", err) - return - } -} - func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool { if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port { return true @@ -425,7 +440,7 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSaf func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { relatedAdd := candidate.RelatedAddress() - return 
ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ + ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ Network: candidate.NetworkType().String(), Address: candidate.Address(), Port: relatedAdd.Port, @@ -433,6 +448,17 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive RelAddr: relatedAdd.Address, RelPort: relatedAdd.Port, }) + if err != nil { + return nil, err + } + + for _, e := range candidate.Extensions() { + if err := ec.AddExtension(e); err != nil { + return nil, err + } + } + + return ec, nil } func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool { diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 6e1f83a9a..8c3d5a571 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -7,7 +7,7 @@ import ( "sync" "time" - "github.com/pion/stun/v2" + "github.com/pion/stun/v3" "github.com/pion/turn/v3" log "github.com/sirupsen/logrus" diff --git a/go.mod b/go.mod index 33aea7618..e840fb343 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,6 @@ require ( github.com/kardianos/service v1.2.3-0.20240613133416-becf2eb62b83 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.27.6 - github.com/pion/ice/v3 v3.0.2 github.com/rs/cors v1.8.0 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.7.0 @@ -24,7 +23,7 @@ require ( golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 google.golang.org/grpc v1.73.0 - google.golang.org/protobuf v1.36.6 + google.golang.org/protobuf v1.36.8 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) @@ -69,10 +68,12 @@ require ( github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 - github.com/pion/logging v0.2.2 + github.com/pion/ice/v4 v4.0.0-00010101000000-000000000000 + github.com/pion/logging v0.2.4 github.com/pion/randutil 
v0.1.0 github.com/pion/stun/v2 v2.0.0 - github.com/pion/transport/v3 v3.0.1 + github.com/pion/stun/v3 v3.0.0 + github.com/pion/transport/v3 v3.0.7 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.22.0 github.com/quic-go/quic-go v0.48.2 @@ -212,8 +213,10 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pion/dtls/v2 v2.2.10 // indirect - github.com/pion/mdns v0.0.12 // indirect + github.com/pion/dtls/v3 v3.0.7 // indirect + github.com/pion/mdns/v2 v2.0.7 // indirect github.com/pion/transport/v2 v2.2.4 // indirect + github.com/pion/turn/v4 v4.1.1 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect @@ -229,6 +232,7 @@ require ( github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect + github.com/wlynxg/anet v0.0.3 // indirect github.com/yuin/goldmark v1.7.1 // indirect github.com/zeebo/blake3 v0.2.3 // indirect go.opencensus.io v0.24.0 // indirect @@ -257,6 +261,6 @@ replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-2 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 -replace github.com/pion/ice/v3 => github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e +replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 diff --git a/go.sum b/go.sum index 562548edd..e9c894354 100644 --- a/go.sum +++ b/go.sum @@ -501,8 +501,8 @@ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJE github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod 
h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= -github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= -github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= +github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 h1:ZJwhKexMlK15B/Ld+1T8VYE2Mt1lk1kf2DlXr46EHcw= +github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190 h1:/ZbExdcDwRq6XgTpTf5I1DPqnC3eInEf0fcmkqR8eSg= github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= @@ -546,21 +546,29 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA= github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= -github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= +github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q= +github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= -github.com/pion/mdns v0.0.12 h1:CiMYlY+O0azojWDmxdNr7ADGrnZ+V6Ilfner+6mSVK8= -github.com/pion/mdns v0.0.12/go.mod 
h1:VExJjv8to/6Wqm1FXK+Ii/Z9tsVk/F5sD/N70cnYFbk= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM= +github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0= github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ= +github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw= +github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU= github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo= github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= -github.com/pion/transport/v3 v3.0.1 h1:gDTlPJwROfSfz6QfSi0ZmeCSkFcnWWiiR9ES0ouANiM= github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0= +github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= +github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8= github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE= +github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc= +github.com/pion/turn/v4 v4.1.1/go.mod h1:2123tHk1O++vmjI5VSD0awT50NywDAq5A2NNNU4Jjs8= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -681,6 +689,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg= +github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -1167,11 +1177,10 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= -google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= +google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 
v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= From d39fcfd62ab7cf956a30249da40bba53f8b8feab Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 1 Sep 2025 13:00:45 -0300 Subject: [PATCH 14/41] [management] Add user approval (#4411) This PR adds user approval functionality to the management system, allowing administrators to manually approve new users joining via domain matching. When enabled, users are blocked with pending approval status until explicitly approved by an admin. Adds UserApprovalRequired setting to control manual user approval requirement Introduces user approval and rejection endpoints with corresponding business logic Prevents pending approval users from adding peers or logging in --- management/server/account.go | 25 ++- management/server/account/manager.go | 2 + management/server/account_test.go | 90 +++++++++ management/server/activity/codes.go | 4 + .../handlers/accounts/accounts_handler.go | 6 +- .../accounts/accounts_handler_test.go | 4 +- .../http/handlers/users/users_handler.go | 87 +++++++-- .../http/handlers/users/users_handler_test.go | 134 ++++++++++++- management/server/mock_server/account_mock.go | 16 ++ management/server/peer.go | 3 + management/server/peer_test.go | 183 ++++++++++++++++++ management/server/permissions/manager.go | 6 +- management/server/types/settings.go | 4 + management/server/types/user.go | 46 +++-- management/server/user.go | 74 +++++++ management/server/user_test.go | 114 +++++++++++ shared/management/http/api/openapi.yml | 67 +++++++ shared/management/http/api/types.gen.go | 20 +- shared/management/status/error.go | 5 + 19 files changed, 842 insertions(+), 48 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index b57550f14..d9638b41a 100644 --- 
a/management/server/account.go +++ b/management/server/account.go @@ -1136,7 +1136,18 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { newUser := types.NewRegularUser(userAuth.UserId) newUser.AccountID = domainAccountID - err := am.Store.SaveUser(ctx, newUser) + + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, domainAccountID) + if err != nil { + return "", err + } + + if settings != nil && settings.Extra != nil && settings.Extra.UserApprovalRequired { + newUser.Blocked = true + newUser.PendingApproval = true + } + + err = am.Store.SaveUser(ctx, newUser) if err != nil { return "", err } @@ -1146,7 +1157,11 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, return "", err } - am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil) + if newUser.PendingApproval { + am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, map[string]any{"pending_approval": true}) + } else { + am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil) + } return domainAccountID, nil } @@ -1795,6 +1810,9 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis PeerInactivityExpirationEnabled: false, PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, RoutingPeerDNSResolutionEnabled: true, + Extra: &types.ExtraSettings{ + UserApprovalRequired: true, + }, }, Onboarding: types.AccountOnboarding{ OnboardingFlowPending: true, @@ -1901,6 +1919,9 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C PeerInactivityExpirationEnabled: false, PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, RoutingPeerDNSResolutionEnabled: true, + Extra: &types.ExtraSettings{ + 
UserApprovalRequired: true, + }, }, } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index c7329a1da..30fbbbc3e 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -32,6 +32,8 @@ type Manager interface { DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error + ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) + RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index 66cf93286..81a921bf9 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3606,3 +3606,93 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { require.Error(t, err, "should fail with invalid peer ID") }) } + +func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create a domain-based account with user approval enabled + existingAccountID := "existing-account" + account := newAccountWithId(context.Background(), existingAccountID, "owner-user", "example.com", false) + account.Settings.Extra = &types.ExtraSettings{ + UserApprovalRequired: true, + } + err = manager.Store.SaveAccount(context.Background(), 
account) + require.NoError(t, err) + + // Set the account as domain primary account + account.IsDomainPrimaryAccount = true + account.DomainCategory = types.PrivateCategory + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Test adding new user to existing account with approval required + newUserID := "new-user-id" + userAuth := nbcontext.UserAuth{ + UserId: newUserID, + Domain: "example.com", + DomainCategory: types.PrivateCategory, + } + + acc, err := manager.Store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + require.True(t, acc.IsDomainPrimaryAccount, "Account should be primary for the domain") + require.Equal(t, "example.com", acc.Domain, "Account domain should match") + + returnedAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), userAuth) + require.NoError(t, err) + require.Equal(t, existingAccountID, returnedAccountID) + + // Verify user was created with pending approval + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, newUserID) + require.NoError(t, err) + assert.True(t, user.Blocked, "User should be blocked when approval is required") + assert.True(t, user.PendingApproval, "User should be pending approval") + assert.Equal(t, existingAccountID, user.AccountID) +} + +func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create a domain-based account without user approval + ownerUserAuth := nbcontext.UserAuth{ + UserId: "owner-user", + Domain: "example.com", + DomainCategory: types.PrivateCategory, + } + existingAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), ownerUserAuth) + require.NoError(t, err) + + // Modify the account to disable user approval + account, err := manager.Store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + 
account.Settings.Extra = &types.ExtraSettings{ + UserApprovalRequired: false, + } + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Test adding new user to existing account without approval required + newUserID := "new-user-id" + userAuth := nbcontext.UserAuth{ + UserId: newUserID, + Domain: "example.com", + DomainCategory: types.PrivateCategory, + } + + returnedAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), userAuth) + require.NoError(t, err) + require.Equal(t, existingAccountID, returnedAccountID) + + // Verify user was created without pending approval + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, newUserID) + require.NoError(t, err) + assert.False(t, user.Blocked, "User should not be blocked when approval is not required") + assert.False(t, user.PendingApproval, "User should not be pending approval") + assert.Equal(t, existingAccountID, user.AccountID) +} diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 6f9619597..5c5989f84 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -177,6 +177,8 @@ const ( AccountNetworkRangeUpdated Activity = 87 PeerIPUpdated Activity = 88 + UserApproved Activity = 89 + UserRejected Activity = 90 AccountDeleted Activity = 99999 ) @@ -284,6 +286,8 @@ var activityMap = map[Activity]Code{ AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"}, PeerIPUpdated: {"Peer IP updated", "peer.ip.update"}, + UserApproved: {"User approved", "user.approve"}, + UserRejected: {"User rejected", "user.reject"}, } // StringCode returns a string code of the activity diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index 9f2afe29d..f1552d0ea 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go 
+++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -11,11 +11,11 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" - "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -198,6 +198,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { if req.Settings.Extra != nil { settings.Extra = &types.ExtraSettings{ PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled, + UserApprovalRequired: req.Settings.Extra.UserApprovalRequired, FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled, FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups, FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled, @@ -327,6 +328,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A if settings.Extra != nil { apiSettings.Extra = &api.AccountExtraSettings{ PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled, + UserApprovalRequired: settings.Extra.UserApprovalRequired, NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled, NetworkTrafficLogsGroups: settings.Extra.FlowGroups, NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled, diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 1dad33a6f..4b9b79fdc 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -15,11 +15,11 @@ import ( "github.com/stretchr/testify/assert" 
nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) func initAccountsTestData(t *testing.T, account *types.Account) *handler { diff --git a/management/server/http/handlers/users/users_handler.go b/management/server/http/handlers/users/users_handler.go index bcd637db4..4e03e5e9b 100644 --- a/management/server/http/handlers/users/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -9,11 +9,11 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/management/server/users" nbcontext "github.com/netbirdio/netbird/management/server/context" ) @@ -31,6 +31,8 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) { router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS") router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS") router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS") 
addUsersTokensEndpoint(accountManager, router) } @@ -323,17 +325,76 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { } isCurrent := user.ID == currenUserID + return &api.User{ - Id: user.ID, - Name: user.Name, - Email: user.Email, - Role: user.Role, - AutoGroups: autoGroups, - Status: userStatus, - IsCurrent: &isCurrent, - IsServiceUser: &user.IsServiceUser, - IsBlocked: user.IsBlocked, - LastLogin: &user.LastLogin, - Issued: &user.Issued, + Id: user.ID, + Name: user.Name, + Email: user.Email, + Role: user.Role, + AutoGroups: autoGroups, + Status: userStatus, + IsCurrent: &isCurrent, + IsServiceUser: &user.IsServiceUser, + IsBlocked: user.IsBlocked, + LastLogin: &user.LastLogin, + Issued: &user.Issued, + PendingApproval: user.PendingApproval, } } + +// approveUser is a POST request to approve a user that is pending approval +func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) + return + } + + vars := mux.Vars(r) + targetUserID := vars["userId"] + if len(targetUserID) == 0 { + util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w) + return + } + + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + userResponse := toUserResponse(user, userAuth.UserId) + util.WriteJSONObject(r.Context(), w, userResponse) +} + +// rejectUser is a DELETE request to reject a user that is pending approval +func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) + return + } + + vars := mux.Vars(r) + targetUserID := 
vars["userId"] + if len(targetUserID) == 0 { + util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w) + return + } + + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index f7dc81919..e08004218 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -16,13 +16,13 @@ import ( "github.com/stretchr/testify/require" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/roles" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -725,3 +725,133 @@ func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[stri } return modules } + +func TestApproveUserEndpoint(t *testing.T) { + adminUser := &types.User{ + Id: "admin-user", + Role: types.UserRoleAdmin, + AccountID: existingAccountID, + AutoGroups: []string{}, + } + + pendingUser := &types.User{ + Id: "pending-user", + Role: types.UserRoleUser, + AccountID: existingAccountID, + Blocked: true, + PendingApproval: true, + AutoGroups: []string{}, + } + + tt := 
[]struct { + name string + expectedStatus int + expectedBody bool + requestingUser *types.User + }{ + { + name: "approve user as admin should return 200", + expectedStatus: 200, + expectedBody: true, + requestingUser: adminUser, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{} + am.ApproveUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) { + approvedUserInfo := &types.UserInfo{ + ID: pendingUser.Id, + Email: "pending@example.com", + Name: "Pending User", + Role: string(pendingUser.Role), + AutoGroups: []string{}, + IsServiceUser: false, + IsBlocked: false, + PendingApproval: false, + LastLogin: time.Now(), + Issued: types.UserIssuedAPI, + } + return approvedUserInfo, nil + } + + handler := newHandler(am) + router := mux.NewRouter() + router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST") + + req, err := http.NewRequest("POST", "/users/pending-user/approve", nil) + require.NoError(t, err) + + userAuth := nbcontext.UserAuth{ + AccountId: existingAccountID, + UserId: tc.requestingUser.Id, + } + ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + + if tc.expectedBody { + var response api.User + err = json.Unmarshal(rr.Body.Bytes(), &response) + require.NoError(t, err) + assert.Equal(t, "pending-user", response.Id) + assert.False(t, response.IsBlocked) + assert.False(t, response.PendingApproval) + } + }) + } +} + +func TestRejectUserEndpoint(t *testing.T) { + adminUser := &types.User{ + Id: "admin-user", + Role: types.UserRoleAdmin, + AccountID: existingAccountID, + AutoGroups: []string{}, + } + + tt := []struct { + name string + expectedStatus int + requestingUser *types.User + }{ + { + name: "reject user as admin should return 200", + expectedStatus: 200, + 
requestingUser: adminUser, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + am := &mock_server.MockAccountManager{} + am.RejectUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { + return nil + } + + handler := newHandler(am) + router := mux.NewRouter() + router.HandleFunc("/users/{userId}/reject", handler.rejectUser).Methods("DELETE") + + req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil) + require.NoError(t, err) + + userAuth := nbcontext.UserAuth{ + AccountId: existingAccountID, + UserId: tc.requestingUser.Id, + } + ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + assert.Equal(t, tc.expectedStatus, rr.Code) + }) + } +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index caba58c8b..003385eb5 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -95,6 +95,8 @@ type MockAccountManager struct { LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error + ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) + RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) HasConnectedChannelFunc func(peerID string) bool GetExternalCacheManagerFunc func() account.ExternalCacheManager @@ -607,6 +609,20 @@ func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, return 
status.Errorf(codes.Unimplemented, "method InviteUser is not implemented") } +func (am *MockAccountManager) ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) { + if am.ApproveUserFunc != nil { + return am.ApproveUserFunc(ctx, accountID, initiatorUserID, targetUserID) + } + return nil, status.Errorf(codes.Unimplemented, "method ApproveUser is not implemented") +} + +func (am *MockAccountManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { + if am.RejectUserFunc != nil { + return am.RejectUserFunc(ctx, accountID, initiatorUserID, targetUserID) + } + return status.Errorf(codes.Unimplemented, "method RejectUser is not implemented") +} + // GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface func (am *MockAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { if am.GetNameServerGroupFunc != nil { diff --git a/management/server/peer.go b/management/server/peer.go index 3c2ebe6b6..81f037499 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -489,6 +489,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if err != nil { return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: user not found") } + if user.PendingApproval { + return nil, nil, nil, status.Errorf(status.PermissionDenied, "user pending approval cannot add peers") + } groupsToAdd = user.AutoGroups opEvent.InitiatorID = userID opEvent.Activity = activity.PeerAddedByUser diff --git a/management/server/peer_test.go b/management/server/peer_test.go index c77bf5e25..31c309430 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -2383,3 +2383,186 @@ func TestBufferUpdateAccountPeers(t *testing.T) { assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. 
New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) } + +func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account + account := newAccountWithId(context.Background(), "test-account", "owner", "", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create user pending approval + pendingUser := types.NewRegularUser("pending-user") + pendingUser.AccountID = account.Id + pendingUser.Blocked = true + pendingUser.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Try to add peer with pending approval user + key, err := wgtypes.GenerateKey() + require.NoError(t, err) + + peer := &nbpeer.Peer{ + Key: key.PublicKey().String(), + Name: "test-peer", + Meta: nbpeer.PeerSystemMeta{ + Hostname: "test-peer", + OS: "linux", + }, + } + + _, _, _, err = manager.AddPeer(context.Background(), "", pendingUser.Id, peer) + require.Error(t, err) + assert.Contains(t, err.Error(), "user pending approval cannot add peers") +} + +func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account + account := newAccountWithId(context.Background(), "test-account", "owner", "", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create regular user (not pending approval) + regularUser := types.NewRegularUser("regular-user") + regularUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), regularUser) + require.NoError(t, err) + + // Try to add peer with regular user + key, err := wgtypes.GenerateKey() + require.NoError(t, err) + + peer := &nbpeer.Peer{ + Key: key.PublicKey().String(), + Name: "test-peer", + Meta: nbpeer.PeerSystemMeta{ + Hostname: 
"test-peer", + OS: "linux", + }, + } + + _, _, _, err = manager.AddPeer(context.Background(), "", regularUser.Id, peer) + require.NoError(t, err, "Regular user should be able to add peers") +} + +func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account + account := newAccountWithId(context.Background(), "test-account", "owner", "", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create user pending approval + pendingUser := types.NewRegularUser("pending-user") + pendingUser.AccountID = account.Id + pendingUser.Blocked = true + pendingUser.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Create a peer using AddPeer method for the pending user (simulate existing peer) + key, err := wgtypes.GenerateKey() + require.NoError(t, err) + + // Set the user to not be pending initially so peer can be added + pendingUser.Blocked = false + pendingUser.PendingApproval = false + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Add peer using regular flow + newPeer := &nbpeer.Peer{ + Key: key.PublicKey().String(), + Name: "test-peer", + Meta: nbpeer.PeerSystemMeta{ + Hostname: "test-peer", + OS: "linux", + WtVersion: "0.28.0", + }, + } + existingPeer, _, _, err := manager.AddPeer(context.Background(), "", pendingUser.Id, newPeer) + require.NoError(t, err) + + // Now set the user back to pending approval after peer was created + pendingUser.Blocked = true + pendingUser.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Try to login with pending approval user + login := types.PeerLogin{ + WireGuardPubKey: existingPeer.Key, + UserID: pendingUser.Id, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "test-peer", + OS: "linux", + }, + } + + _, _, _, 
err = manager.LoginPeer(context.Background(), login) + require.Error(t, err) + e, ok := status.FromError(err) + require.True(t, ok, "error is not a gRPC status error") + assert.Equal(t, status.PermissionDenied, e.Type(), "expected PermissionDenied error code") +} + +func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account + account := newAccountWithId(context.Background(), "test-account", "owner", "", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create regular user (not pending approval) + regularUser := types.NewRegularUser("regular-user") + regularUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), regularUser) + require.NoError(t, err) + + // Add peer using regular flow for the regular user + key, err := wgtypes.GenerateKey() + require.NoError(t, err) + + newPeer := &nbpeer.Peer{ + Key: key.PublicKey().String(), + Name: "test-peer", + Meta: nbpeer.PeerSystemMeta{ + Hostname: "test-peer", + OS: "linux", + WtVersion: "0.28.0", + }, + } + existingPeer, _, _, err := manager.AddPeer(context.Background(), "", regularUser.Id, newPeer) + require.NoError(t, err) + + // Try to login with regular user + login := types.PeerLogin{ + WireGuardPubKey: existingPeer.Key, + UserID: regularUser.Id, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "test-peer", + OS: "linux", + }, + } + + _, _, _, err = manager.LoginPeer(context.Background(), login) + require.NoError(t, err, "Regular user should be able to login peers") +} diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 0ab244243..891fa59bb 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -54,10 +54,14 @@ func (m *managerImpl) ValidateUserPermissions( return false, status.NewUserNotFoundError(userID) } - if user.IsBlocked() { + if user.IsBlocked() && 
!user.PendingApproval { return false, status.NewUserBlockedError() } + if user.IsBlocked() && user.PendingApproval { + return false, status.NewUserPendingApprovalError() + } + if err := m.ValidateAccountAccess(ctx, accountID, user, false); err != nil { return false, err } diff --git a/management/server/types/settings.go b/management/server/types/settings.go index 56c33da3b..b4afb2f5e 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -83,6 +83,9 @@ type ExtraSettings struct { // PeerApprovalEnabled enables or disables the need for peers bo be approved by an administrator PeerApprovalEnabled bool + // UserApprovalRequired enables or disables the need for users joining via domain matching to be approved by an administrator + UserApprovalRequired bool + // IntegratedValidator is the string enum for the integrated validator type IntegratedValidator string // IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations @@ -99,6 +102,7 @@ type ExtraSettings struct { func (e *ExtraSettings) Copy() *ExtraSettings { return &ExtraSettings{ PeerApprovalEnabled: e.PeerApprovalEnabled, + UserApprovalRequired: e.UserApprovalRequired, IntegratedValidatorGroups: slices.Clone(e.IntegratedValidatorGroups), IntegratedValidator: e.IntegratedValidator, FlowEnabled: e.FlowEnabled, diff --git a/management/server/types/user.go b/management/server/types/user.go index 783fe14da..beb3586df 100644 --- a/management/server/types/user.go +++ b/management/server/types/user.go @@ -64,6 +64,7 @@ type UserInfo struct { NonDeletable bool `json:"non_deletable"` LastLogin time.Time `json:"last_login"` Issued string `json:"issued"` + PendingApproval bool `json:"pending_approval"` IntegrationReference integration_reference.IntegrationReference `json:"-"` } @@ -84,6 +85,8 @@ type User struct { PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"` // Blocked indicates whether 
the user is blocked. Blocked users can't use the system. Blocked bool + // PendingApproval indicates whether the user requires approval before being activated + PendingApproval bool // LastLogin is the last time the user logged in to IdP LastLogin *time.Time // CreatedAt records the time the user was created @@ -141,16 +144,17 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { if userData == nil { return &UserInfo{ - ID: u.Id, - Email: "", - Name: u.ServiceUserName, - Role: string(u.Role), - AutoGroups: u.AutoGroups, - Status: string(UserStatusActive), - IsServiceUser: u.IsServiceUser, - IsBlocked: u.Blocked, - LastLogin: u.GetLastLogin(), - Issued: u.Issued, + ID: u.Id, + Email: "", + Name: u.ServiceUserName, + Role: string(u.Role), + AutoGroups: u.AutoGroups, + Status: string(UserStatusActive), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.GetLastLogin(), + Issued: u.Issued, + PendingApproval: u.PendingApproval, }, nil } if userData.ID != u.Id { @@ -163,16 +167,17 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { } return &UserInfo{ - ID: u.Id, - Email: userData.Email, - Name: userData.Name, - Role: string(u.Role), - AutoGroups: autoGroups, - Status: string(userStatus), - IsServiceUser: u.IsServiceUser, - IsBlocked: u.Blocked, - LastLogin: u.GetLastLogin(), - Issued: u.Issued, + ID: u.Id, + Email: userData.Email, + Name: userData.Name, + Role: string(u.Role), + AutoGroups: autoGroups, + Status: string(userStatus), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.GetLastLogin(), + Issued: u.Issued, + PendingApproval: u.PendingApproval, }, nil } @@ -194,6 +199,7 @@ func (u *User) Copy() *User { ServiceUserName: u.ServiceUserName, PATs: pats, Blocked: u.Blocked, + PendingApproval: u.PendingApproval, LastLogin: u.LastLogin, CreatedAt: u.CreatedAt, Issued: u.Issued, diff --git a/management/server/user.go b/management/server/user.go index e5a4dbcea..04b2ce2d0 100644 --- 
a/management/server/user.go +++ b/management/server/user.go @@ -1207,3 +1207,77 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut return userWithPermissions, nil } + +// ApproveUser approves a user that is pending approval +func (am *DefaultAccountManager) ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Update) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) + if err != nil { + return nil, err + } + + if user.AccountID != accountID { + return nil, status.NewUserNotFoundError(targetUserID) + } + + if !user.PendingApproval { + return nil, status.Errorf(status.InvalidArgument, "user %s is not pending approval", targetUserID) + } + + user.Blocked = false + user.PendingApproval = false + + err = am.Store.SaveUser(ctx, user) + if err != nil { + return nil, err + } + + am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserApproved, nil) + + userInfo, err := am.getUserInfo(ctx, user, accountID) + if err != nil { + return nil, err + } + + return userInfo, nil +} + +// RejectUser rejects a user that is pending approval by deleting them +func (am *DefaultAccountManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) + if err != nil { + return err + } + + if user.AccountID 
!= accountID { + return status.NewUserNotFoundError(targetUserID) + } + + if !user.PendingApproval { + return status.Errorf(status.InvalidArgument, "user %s is not pending approval", targetUserID) + } + + err = am.DeleteUser(ctx, accountID, initiatorUserID, targetUserID) + if err != nil { + return err + } + + am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserRejected, nil) + + return nil +} diff --git a/management/server/user_test.go b/management/server/user_test.go index 8ab0c1565..9638559f9 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1746,3 +1746,117 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions { return permissions } + +func TestApproveUser(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account with admin and pending approval user + account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create admin user + adminUser := types.NewAdminUser("admin-user") + adminUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), adminUser) + require.NoError(t, err) + + // Create user pending approval + pendingUser := types.NewRegularUser("pending-user") + pendingUser.AccountID = account.Id + pendingUser.Blocked = true + pendingUser.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Test successful approval + approvedUser, err := manager.ApproveUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id) + require.NoError(t, err) + assert.False(t, approvedUser.IsBlocked) + assert.False(t, approvedUser.PendingApproval) + + // Verify user is updated in store + updatedUser, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, pendingUser.Id) + require.NoError(t, 
err) + assert.False(t, updatedUser.Blocked) + assert.False(t, updatedUser.PendingApproval) + + // Test approval of non-pending user should fail + _, err = manager.ApproveUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id) + require.Error(t, err) + assert.Contains(t, err.Error(), "not pending approval") + + // Test approval by non-admin should fail + regularUser := types.NewRegularUser("regular-user") + regularUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), regularUser) + require.NoError(t, err) + + pendingUser2 := types.NewRegularUser("pending-user-2") + pendingUser2.AccountID = account.Id + pendingUser2.Blocked = true + pendingUser2.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser2) + require.NoError(t, err) + + _, err = manager.ApproveUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id) + require.Error(t, err) +} + +func TestRejectUser(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + // Create account with admin and pending approval user + account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false) + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + + // Create admin user + adminUser := types.NewAdminUser("admin-user") + adminUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), adminUser) + require.NoError(t, err) + + // Create user pending approval + pendingUser := types.NewRegularUser("pending-user") + pendingUser.AccountID = account.Id + pendingUser.Blocked = true + pendingUser.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser) + require.NoError(t, err) + + // Test successful rejection + err = manager.RejectUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id) + require.NoError(t, err) + + // Verify user is deleted from store + _, err = 
manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, pendingUser.Id) + require.Error(t, err) + + // Test rejection of non-pending user should fail + regularUser := types.NewRegularUser("regular-user") + regularUser.AccountID = account.Id + err = manager.Store.SaveUser(context.Background(), regularUser) + require.NoError(t, err) + + err = manager.RejectUser(context.Background(), account.Id, adminUser.Id, regularUser.Id) + require.Error(t, err) + assert.Contains(t, err.Error(), "not pending approval") + + // Test rejection by non-admin should fail + pendingUser2 := types.NewRegularUser("pending-user-2") + pendingUser2.AccountID = account.Id + pendingUser2.Blocked = true + pendingUser2.PendingApproval = true + err = manager.Store.SaveUser(context.Background(), pendingUser2) + require.NoError(t, err) + + err = manager.RejectUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id) + require.Error(t, err) +} diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index cf4b6d625..9a531b2ff 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -158,6 +158,10 @@ components: description: (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin. type: boolean example: true + user_approval_required: + description: Enables manual approval for new users joining via domain matching. When enabled, users are blocked with pending approval status until explicitly approved by an admin. + type: boolean + example: false network_traffic_logs_enabled: description: Enables or disables network traffic logging. If enabled, all network traffic events from peers will be stored. 
type: boolean @@ -174,6 +178,7 @@ components: example: true required: - peer_approval_enabled + - user_approval_required - network_traffic_logs_enabled - network_traffic_logs_groups - network_traffic_packet_counter_enabled @@ -235,6 +240,10 @@ components: description: Is true if this user is blocked. Blocked users can't use the system type: boolean example: false + pending_approval: + description: Is true if this user requires approval before being activated. Only applicable for users joining via domain matching when user_approval_required is enabled. + type: boolean + example: false issued: description: How user was issued by API or Integration type: string @@ -249,6 +258,7 @@ components: - auto_groups - status - is_blocked + - pending_approval UserPermissions: type: object properties: @@ -2544,6 +2554,63 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/users/{userId}/approve: + post: + summary: Approve user + description: Approve a user that is pending approval + tags: [ Users ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The unique identifier of a user + responses: + '200': + description: Returns the approved user + content: + application/json: + schema: + "$ref": "#/components/schemas/User" + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/users/{userId}/reject: + delete: + summary: Reject user + description: Reject a user that is pending approval by removing them from the account + tags: [ Users ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The unique identifier of a user + responses: + '200': + 
description: User rejected successfully + content: {} + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/users/current: get: summary: Retrieve current user diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index cffc9e735..28b89633c 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -268,6 +268,9 @@ type AccountExtraSettings struct { // PeerApprovalEnabled (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin. PeerApprovalEnabled bool `json:"peer_approval_enabled"` + + // UserApprovalRequired Enables manual approval for new users joining via domain matching. When enabled, users are blocked with pending approval status until explicitly approved by an admin. + UserApprovalRequired bool `json:"user_approval_required"` } // AccountOnboarding defines model for AccountOnboarding. @@ -1015,8 +1018,6 @@ type OSVersionCheck struct { // Peer defines model for Peer. type Peer struct { - // CreatedAt Peer creation date (UTC) - CreatedAt time.Time `json:"created_at"` // ApprovalRequired (Cloud only) Indicates whether peer needs approval ApprovalRequired bool `json:"approval_required"` @@ -1032,6 +1033,9 @@ type Peer struct { // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country CountryCode CountryCode `json:"country_code"` + // CreatedAt Peer creation date (UTC) + CreatedAt time.Time `json:"created_at"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. 
peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` @@ -1098,8 +1102,6 @@ type Peer struct { // PeerBatch defines model for PeerBatch. type PeerBatch struct { - // CreatedAt Peer creation date (UTC) - CreatedAt time.Time `json:"created_at"` // AccessiblePeersCount Number of accessible peers AccessiblePeersCount int `json:"accessible_peers_count"` @@ -1118,6 +1120,9 @@ type PeerBatch struct { // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country CountryCode CountryCode `json:"country_code"` + // CreatedAt Peer creation date (UTC) + CreatedAt time.Time `json:"created_at"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` @@ -1774,8 +1779,11 @@ type User struct { LastLogin *time.Time `json:"last_login,omitempty"` // Name User's name from idp provider - Name string `json:"name"` - Permissions *UserPermissions `json:"permissions,omitempty"` + Name string `json:"name"` + + // PendingApproval Is true if this user requires approval before being activated. Only applicable for users joining via domain matching when user_approval_required is enabled. 
+ PendingApproval bool `json:"pending_approval"` + Permissions *UserPermissions `json:"permissions,omitempty"` // Role User's NetBird account role Role string `json:"role"` diff --git a/shared/management/status/error.go b/shared/management/status/error.go index 52d27b062..1e914babb 100644 --- a/shared/management/status/error.go +++ b/shared/management/status/error.go @@ -113,6 +113,11 @@ func NewUserBlockedError() error { return Errorf(PermissionDenied, "user is blocked") } +// NewUserPendingApprovalError creates a new Error with PermissionDenied type for a blocked user pending approval +func NewUserPendingApprovalError() error { + return Errorf(PermissionDenied, "user is pending approval") +} + // NewPeerNotRegisteredError creates a new Error with Unauthenticated type unregistered peer func NewPeerNotRegisteredError() error { return Errorf(Unauthenticated, "peer is not registered") From 71e944fa57868ce38a30bdd7267de5dbd52a72cb Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 1 Sep 2025 19:51:06 +0200 Subject: [PATCH 15/41] [relay] Let relay accept any origin (#4426) --- relay/server/listener/ws/listener.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 332127660..12219e29b 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -73,7 +73,12 @@ func (l *Listener) Shutdown(ctx context.Context) error { func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { connRemoteAddr := remoteAddr(r) - wsConn, err := websocket.Accept(w, r, nil) + + acceptOptions := &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + } + + wsConn, err := websocket.Accept(w, r, acceptOptions) if err != nil { log.Errorf("failed to accept ws connection from %s: %s", connRemoteAddr, err) return From a8dcff69c2f491fcc3d0a864a385cb91c152aa34 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga 
Date: Thu, 4 Sep 2025 23:07:03 +0300 Subject: [PATCH 16/41] [management] Add peers manager to integrations (#4405) --- client/cmd/testutil_test.go | 29 ++++----- client/internal/engine_test.go | 17 ++--- client/server/server_test.go | 25 ++++---- go.mod | 2 +- go.sum | 4 +- management/internals/server/controllers.go | 2 +- shared/management/client/client_test.go | 73 +++++++++++----------- 7 files changed, 78 insertions(+), 74 deletions(-) diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index e45443751..3f08d12ba 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -9,29 +9,26 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" + "google.golang.org/grpc" + "github.com/netbirdio/management-integrations/integrations" + clientProto "github.com/netbirdio/netbird/client/proto" + client "github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/management/internals/server/config" + mgmt "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" - - "github.com/netbirdio/netbird/util" - - "google.golang.org/grpc" - - "github.com/netbirdio/management-integrations/integrations" - - clientProto "github.com/netbirdio/netbird/client/proto" - client "github.com/netbirdio/netbird/client/server" - mgmt "github.com/netbirdio/netbird/management/server" mgmtProto "github.com/netbirdio/netbird/shared/management/proto" sigProto 
"github.com/netbirdio/netbird/shared/signal/proto" sig "github.com/netbirdio/netbird/signal/server" + "github.com/netbirdio/netbird/util" ) func startTestingServices(t *testing.T) string { @@ -90,15 +87,19 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp if err != nil { return nil, nil } - iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) - metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) - require.NoError(t, err) ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) - settingsMockManager := settings.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl) + peersmanager := peers.NewManager(store, permissionsManagerMock) + + iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, eventStore) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + require.NoError(t, err) + + settingsMockManager := settings.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() settingsMockManager.EXPECT(). 
diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index fc58dbdba..afe4622b3 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -19,17 +19,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" + wgdevice "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" - wgdevice "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun/netstack" - "github.com/netbirdio/management-integrations/integrations" - "github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" @@ -45,9 +41,12 @@ import ( "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -1555,7 +1554,10 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri if err != nil { return nil, "", err } - ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) + + permissionsManager := permissions.NewManager(store) + peersManager := peers.NewManager(store, permissionsManager) + ia, _ := 
integrations.NewIntegratedValidator(context.Background(), peersManager, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) @@ -1572,7 +1574,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri Return(&types.ExtraSettings{}, nil). AnyTimes() - permissionsManager := permissions.NewManager(store) groupsManager := groups.NewManagerMock() accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) diff --git a/client/server/server_test.go b/client/server/server_test.go index 24ff9fb0c..493c8601a 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -10,25 +10,24 @@ import ( "time" "github.com/golang/mock/gomock" - "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel" - - "github.com/netbirdio/management-integrations/integrations" - "github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server/groups" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" daemonProto "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + 
"github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -294,15 +293,19 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + permissionsManagerMock := permissions.NewMockManager(ctrl) + peersManager := peers.NewManager(store, permissionsManagerMock) + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) - permissionsManagerMock := permissions.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) diff --git a/go.mod b/go.mod index e840fb343..68730bf53 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190 + github.com/netbirdio/management-integrations/integrations v0.0.0-20250826184705-1866b8dd841f github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index e9c894354..4783c47eb 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ 
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 h1:ZJwhKexMlK15B/Ld+1T8VYE2Mt1lk1kf2DlXr46EHcw= github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190 h1:/ZbExdcDwRq6XgTpTf5I1DPqnC3eInEf0fcmkqR8eSg= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250826184705-1866b8dd841f h1:r1gnjw0TfkaDLSCmAE3g5N5ulcd5WpFHaGrqQomCXP4= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250826184705-1866b8dd841f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index b351f3bc9..f9023b204 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -20,7 +20,7 @@ func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager { func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator { return Create(s, func() integrated_validator.IntegratedValidator { - integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.EventStore()) + integratedPeerValidator, err := 
integrations.NewIntegratedValidator(context.Background(), s.PeersManager(), s.EventStore()) if err != nil { log.Errorf("failed to create integrated peer validator: %v", err) } diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index 3037b44bb..b04cdd96a 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -9,34 +9,30 @@ import ( "time" "github.com/golang/mock/gomock" - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/client/system" - "github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/types" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - - "github.com/netbirdio/management-integrations/integrations" - - "github.com/netbirdio/netbird/encryption" - mgmt "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/mock_server" - mgmtProto "github.com/netbirdio/netbird/shared/management/proto" - + "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/internals/server/config" + mgmt "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + 
"github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" + mgmtProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util" ) @@ -72,13 +68,29 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + permissionsManagerMock := permissions.NewMockManager(ctrl) + permissionsManagerMock. + EXPECT(). + ValidateUserPermissions( + gomock.Any(), + gomock.Any(), + gomock.Any(), + gomock.Any(), + gomock.Any(), + ). + Return(true, nil). + AnyTimes() + + peersManger := peers.NewManager(store, permissionsManagerMock) + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager. EXPECT(). @@ -95,19 +107,6 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { Return(&types.ExtraSettings{}, nil). AnyTimes() - permissionsManagerMock := permissions.NewMockManager(ctrl) - permissionsManagerMock. - EXPECT(). - ValidateUserPermissions( - gomock.Any(), - gomock.Any(), - gomock.Any(), - gomock.Any(), - gomock.Any(), - ). - Return(true, nil). 
- AnyTimes() - accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) From dfebdf1444f559665bcd29c37306df154212f1dd Mon Sep 17 00:00:00 2001 From: Diego Romar Date: Thu, 4 Sep 2025 18:00:10 -0300 Subject: [PATCH 17/41] [internal] Add missing assignment of iFaceDiscover when netstack is disabled (#4444) The internal updateInterfaces() function expects iFaceDiscover to not be nil --- client/internal/stdnet/stdnet.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/internal/stdnet/stdnet.go b/client/internal/stdnet/stdnet.go index 171cc42cb..4b031c05c 100644 --- a/client/internal/stdnet/stdnet.go +++ b/client/internal/stdnet/stdnet.go @@ -40,7 +40,7 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri if netstack.IsEnabled() { n.iFaceDiscover = pionDiscover{} } else { - newMobileIFaceDiscover(iFaceDiscover) + n.iFaceDiscover = newMobileIFaceDiscover(iFaceDiscover) } return n, n.UpdateInterfaces() } From 786ca6fc79fe3cfa76b1d374b4c2dbb610a15681 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 5 Sep 2025 11:02:29 +0200 Subject: [PATCH 18/41] Do not block Offer processing from relay worker (#4435) - do not miss ICE offers when relay worker busy - close p2p connection before recreate agent --- client/internal/peer/handshaker.go | 16 ++--- client/internal/peer/handshaker_listener.go | 62 +++++++++++++++++++ .../internal/peer/handshaker_listener_test.go | 39 ++++++++++++ client/internal/peer/worker_ice.go | 6 +- 4 files changed, 110 insertions(+), 13 deletions(-) create mode 100644 client/internal/peer/handshaker_listener.go create mode 100644 client/internal/peer/handshaker_listener_test.go diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index 
3cbf74cfd..42eaea683 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -43,13 +43,6 @@ type OfferAnswer struct { SessionID *ICESessionID } -func (oa *OfferAnswer) SessionIDString() string { - if oa.SessionID == nil { - return "unknown" - } - return oa.SessionID.String() -} - type Handshaker struct { mu sync.Mutex log *log.Entry @@ -57,7 +50,7 @@ type Handshaker struct { signaler *Signaler ice *WorkerICE relay *WorkerRelay - onNewOfferListeners []func(*OfferAnswer) + onNewOfferListeners []*OfferListener // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection remoteOffersCh chan OfferAnswer @@ -78,7 +71,8 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W } func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) { - h.onNewOfferListeners = append(h.onNewOfferListeners, offer) + l := NewOfferListener(offer) + h.onNewOfferListeners = append(h.onNewOfferListeners, l) } func (h *Handshaker) Listen(ctx context.Context) { @@ -91,13 +85,13 @@ func (h *Handshaker) Listen(ctx context.Context) { continue } for _, listener := range h.onNewOfferListeners { - listener(&remoteOfferAnswer) + listener.Notify(&remoteOfferAnswer) } h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) case remoteOfferAnswer := <-h.remoteAnswerCh: h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) for _, listener := range h.onNewOfferListeners { - listener(&remoteOfferAnswer) + listener.Notify(&remoteOfferAnswer) } case <-ctx.Done(): h.log.Infof("stop listening for remote offers and answers") diff --git a/client/internal/peer/handshaker_listener.go 
b/client/internal/peer/handshaker_listener.go new file mode 100644 index 000000000..e2d3f3f38 --- /dev/null +++ b/client/internal/peer/handshaker_listener.go @@ -0,0 +1,62 @@ +package peer + +import ( + "sync" +) + +type callbackFunc func(remoteOfferAnswer *OfferAnswer) + +func (oa *OfferAnswer) SessionIDString() string { + if oa.SessionID == nil { + return "unknown" + } + return oa.SessionID.String() +} + +type OfferListener struct { + fn callbackFunc + running bool + latest *OfferAnswer + mu sync.Mutex +} + +func NewOfferListener(fn callbackFunc) *OfferListener { + return &OfferListener{ + fn: fn, + } +} + +func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) { + o.mu.Lock() + defer o.mu.Unlock() + + // Store the latest offer + o.latest = remoteOfferAnswer + + // If already running, the running goroutine will pick up this latest value + if o.running { + return + } + + // Start processing + o.running = true + + // Process in a goroutine to avoid blocking the caller + go func(remoteOfferAnswer *OfferAnswer) { + for { + o.fn(remoteOfferAnswer) + + o.mu.Lock() + if o.latest == nil { + // No more work to do + o.running = false + o.mu.Unlock() + return + } + remoteOfferAnswer = o.latest + // Clear the latest to mark it as being processed + o.latest = nil + o.mu.Unlock() + } + }(remoteOfferAnswer) +} diff --git a/client/internal/peer/handshaker_listener_test.go b/client/internal/peer/handshaker_listener_test.go new file mode 100644 index 000000000..8363741a5 --- /dev/null +++ b/client/internal/peer/handshaker_listener_test.go @@ -0,0 +1,39 @@ +package peer + +import ( + "testing" + "time" +) + +func Test_newOfferListener(t *testing.T) { + dummyOfferAnswer := &OfferAnswer{} + runChan := make(chan struct{}, 10) + + longRunningFn := func(remoteOfferAnswer *OfferAnswer) { + time.Sleep(1 * time.Second) + runChan <- struct{}{} + } + + hl := NewOfferListener(longRunningFn) + + hl.Notify(dummyOfferAnswer) + hl.Notify(dummyOfferAnswer) + hl.Notify(dummyOfferAnswer) + + 
// Wait for exactly 2 callbacks + for i := 0; i < 2; i++ { + select { + case <-runChan: + case <-time.After(3 * time.Second): + t.Fatal("Timeout waiting for callback") + } + } + + // Verify no additional callbacks happen + select { + case <-runChan: + t.Fatal("Unexpected additional callback") + case <-time.After(100 * time.Millisecond): + t.Log("Correctly received exactly 2 callbacks") + } +} diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index e80641770..896c55b6c 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -122,7 +122,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.log.Warnf("failed to close ICE agent: %s", err) } w.agent = nil - // todo consider to switch to Relay connection while establishing a new ICE connection } var preferredCandidateTypes []ice.CandidateType @@ -410,7 +409,10 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia case ice.ConnectionStateConnected: w.lastKnownState = ice.ConnectionStateConnected return - case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected: + case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed: + // ice.ConnectionStateClosed happens when we recreate the agent. 
For the P2P to TURN switch important to + // notify the conn.onICEStateDisconnected changes to update the current used priority + if w.lastKnownState == ice.ConnectionStateConnected { w.lastKnownState = ice.ConnectionStateDisconnected w.conn.onICEStateDisconnected() From d33f88df82f8aa6743001950fd9254f07c469126 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 5 Sep 2025 18:11:23 +0200 Subject: [PATCH 19/41] [management] only allow user devices to be expired (#4445) --- management/server/account.go | 4 +++- management/server/user.go | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/management/server/account.go b/management/server/account.go index d9638b41a..ee9f294a4 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1714,7 +1714,9 @@ func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, account log.WithContext(ctx).Errorf("failed to get invalidated peer %s for account %s: %v", peerID, accountID, err) continue } - peers = append(peers, peer) + if peer.UserID != "" { + peers = append(peers, peer) + } } if len(peers) > 0 { err := am.expireAndUpdatePeers(ctx, accountID, peers) diff --git a/management/server/user.go b/management/server/user.go index 04b2ce2d0..3c7c3f433 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -942,6 +942,11 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key) + if peer.UserID == "" { + // we do not want to expire peers that are added via setup key + continue + } + if peer.Status.LoginExpired { continue } From ad8fcda67b1ad391491eae2f42b9bacd88c8b790 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sat, 6 Sep 2025 10:49:28 +0200 Subject: [PATCH 20/41] [client] Move some sys info to static place (#4446) This PR refactors the system information collection code by moving static system 
information gathering to a dedicated location and separating platform-specific implementations. The primary goal is to improve code organization and maintainability by centralizing static info collection logic. Key changes: - Centralized static info collection into dedicated files with platform-specific implementations - Moved `StaticInfo` struct definition to the main static_info.go file - Added async initialization function `UpdateStaticInfoAsync()` across all platforms --- client/cmd/login.go | 2 +- client/cmd/service_controller.go | 2 +- client/system/info.go | 15 -- client/system/info_android.go | 5 + client/system/info_darwin.go | 6 +- client/system/info_freebsd.go | 5 + client/system/info_ios.go | 5 + client/system/info_linux.go | 6 +- client/system/info_windows.go | 161 ++--------------- client/system/static_info.go | 42 ++--- client/system/static_info_stub.go | 8 - client/system/static_info_update.go | 35 ++++ client/system/static_info_update_windows.go | 184 ++++++++++++++++++++ 13 files changed, 278 insertions(+), 198 deletions(-) delete mode 100644 client/system/static_info_stub.go create mode 100644 client/system/static_info_update.go create mode 100644 client/system/static_info_update_windows.go diff --git a/client/cmd/login.go b/client/cmd/login.go index 92de6abdb..3ac211805 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -227,7 +227,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, } // update host's static platform and system information - system.UpdateStaticInfo() + system.UpdateStaticInfoAsync() configFilePath, err := activeProf.FilePath() if err != nil { diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 50fb35d5e..0545ce6b7 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -27,7 +27,7 @@ func (p *program) Start(svc service.Service) error { log.Info("starting NetBird service") //nolint // Collect static system and platform 
information - system.UpdateStaticInfo() + system.UpdateStaticInfoAsync() // in any case, even if configuration does not exists we run daemon to serve CLI gRPC API. p.serv = grpc.NewServer() diff --git a/client/system/info.go b/client/system/info.go index ea3f6063a..ceb1682f3 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -95,14 +95,6 @@ func (i *Info) SetFlags( i.LazyConnectionEnabled = lazyConnectionEnabled } -// StaticInfo is an object that contains machine information that does not change -type StaticInfo struct { - SystemSerialNumber string - SystemProductName string - SystemManufacturer string - Environment Environment -} - // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context func extractUserAgent(ctx context.Context) string { md, hasMeta := metadata.FromOutgoingContext(ctx) @@ -195,10 +187,3 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro return info, nil } - -// UpdateStaticInfo asynchronously updates static system and platform information -func UpdateStaticInfo() { - go func() { - _ = updateStaticInfo() - }() -} diff --git a/client/system/info_android.go b/client/system/info_android.go index 56fe0741d..78895bfa8 100644 --- a/client/system/info_android.go +++ b/client/system/info_android.go @@ -15,6 +15,11 @@ import ( "github.com/netbirdio/netbird/version" ) +// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update +func UpdateStaticInfoAsync() { + // do nothing +} + // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { kernel := "android" diff --git a/client/system/info_darwin.go b/client/system/info_darwin.go index f105ada60..caa344737 100644 --- a/client/system/info_darwin.go +++ b/client/system/info_darwin.go @@ -19,6 +19,10 @@ import ( "github.com/netbirdio/netbird/version" ) +func UpdateStaticInfoAsync() { + go updateStaticInfo() +} + // GetInfo retrieves and parses the system 
information func GetInfo(ctx context.Context) *Info { utsname := unix.Utsname{} @@ -41,7 +45,7 @@ func GetInfo(ctx context.Context) *Info { } start := time.Now() - si := updateStaticInfo() + si := getStaticInfo() if time.Since(start) > 1*time.Second { log.Warnf("updateStaticInfo took %s", time.Since(start)) } diff --git a/client/system/info_freebsd.go b/client/system/info_freebsd.go index bed6711de..8e1353151 100644 --- a/client/system/info_freebsd.go +++ b/client/system/info_freebsd.go @@ -18,6 +18,11 @@ import ( "github.com/netbirdio/netbird/version" ) +// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update +func UpdateStaticInfoAsync() { + // do nothing +} + // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { out := _getInfo() diff --git a/client/system/info_ios.go b/client/system/info_ios.go index 897ec0a35..705c37920 100644 --- a/client/system/info_ios.go +++ b/client/system/info_ios.go @@ -10,6 +10,11 @@ import ( "github.com/netbirdio/netbird/version" ) +// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update +func UpdateStaticInfoAsync() { + // do nothing +} + // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { diff --git a/client/system/info_linux.go b/client/system/info_linux.go index 9bfc82009..6c7a23b95 100644 --- a/client/system/info_linux.go +++ b/client/system/info_linux.go @@ -23,6 +23,10 @@ var ( getSystemInfo = defaultSysInfoImplementation ) +func UpdateStaticInfoAsync() { + go updateStaticInfo() +} + // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { info := _getInfo() @@ -48,7 +52,7 @@ func GetInfo(ctx context.Context) *Info { } start := time.Now() - si := updateStaticInfo() + si := getStaticInfo() if time.Since(start) > 1*time.Second { log.Warnf("updateStaticInfo took %s", time.Since(start)) } diff --git a/client/system/info_windows.go 
b/client/system/info_windows.go index 6f05ded20..e67356f57 100644 --- a/client/system/info_windows.go +++ b/client/system/info_windows.go @@ -2,66 +2,48 @@ package system import ( "context" - "fmt" "os" "runtime" - "strings" "time" log "github.com/sirupsen/logrus" - "github.com/yusufpapurcu/wmi" - "golang.org/x/sys/windows/registry" "github.com/netbirdio/netbird/version" ) -type Win32_OperatingSystem struct { - Caption string -} - -type Win32_ComputerSystem struct { - Manufacturer string -} - -type Win32_ComputerSystemProduct struct { - Name string -} - -type Win32_BIOS struct { - SerialNumber string +func UpdateStaticInfoAsync() { + go updateStaticInfo() } // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { - osName, osVersion := getOSNameAndVersion() - buildVersion := getBuildVersion() - - addrs, err := networkAddresses() - if err != nil { - log.Warnf("failed to discover network addresses: %s", err) - } - start := time.Now() - si := updateStaticInfo() + si := getStaticInfo() if time.Since(start) > 1*time.Second { log.Warnf("updateStaticInfo took %s", time.Since(start)) } gio := &Info{ Kernel: "windows", - OSVersion: osVersion, + OSVersion: si.OSVersion, Platform: "unknown", - OS: osName, + OS: si.OSName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), - KernelVersion: buildVersion, - NetworkAddresses: addrs, + KernelVersion: si.BuildVersion, SystemSerialNumber: si.SystemSerialNumber, SystemProductName: si.SystemProductName, SystemManufacturer: si.SystemManufacturer, Environment: si.Environment, } + addrs, err := networkAddresses() + if err != nil { + log.Warnf("failed to discover network addresses: %s", err) + } else { + gio.NetworkAddresses = addrs + } + systemHostname, _ := os.Hostname() gio.Hostname = extractDeviceName(ctx, systemHostname) gio.NetbirdVersion = version.NetbirdVersion() @@ -69,120 +51,3 @@ func GetInfo(ctx context.Context) *Info { return gio } - -func sysInfo() (serialNumber string, productName string, 
manufacturer string) { - var err error - serialNumber, err = sysNumber() - if err != nil { - log.Warnf("failed to get system serial number: %s", err) - } - - productName, err = sysProductName() - if err != nil { - log.Warnf("failed to get system product name: %s", err) - } - - manufacturer, err = sysManufacturer() - if err != nil { - log.Warnf("failed to get system manufacturer: %s", err) - } - - return serialNumber, productName, manufacturer -} - -func getOSNameAndVersion() (string, string) { - var dst []Win32_OperatingSystem - query := wmi.CreateQuery(&dst, "") - err := wmi.Query(query, &dst) - if err != nil { - log.Error(err) - return "Windows", getBuildVersion() - } - - if len(dst) == 0 { - return "Windows", getBuildVersion() - } - - split := strings.Split(dst[0].Caption, " ") - - if len(split) <= 3 { - return "Windows", getBuildVersion() - } - - name := split[1] - version := split[2] - if split[2] == "Server" { - name = fmt.Sprintf("%s %s", split[1], split[2]) - version = split[3] - } - - return name, version -} - -func getBuildVersion() string { - k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE) - if err != nil { - log.Error(err) - return "0.0.0.0" - } - defer func() { - deferErr := k.Close() - if deferErr != nil { - log.Error(deferErr) - } - }() - - major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber") - if err != nil { - log.Error(err) - } - minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber") - if err != nil { - log.Error(err) - } - build, _, err := k.GetStringValue("CurrentBuildNumber") - if err != nil { - log.Error(err) - } - // Update Build Revision - ubr, _, err := k.GetIntegerValue("UBR") - if err != nil { - log.Error(err) - } - ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr) - return ver -} - -func sysNumber() (string, error) { - var dst []Win32_BIOS - query := wmi.CreateQuery(&dst, "") - err := wmi.Query(query, &dst) - if err != nil { - return 
"", err - } - return dst[0].SerialNumber, nil -} - -func sysProductName() (string, error) { - var dst []Win32_ComputerSystemProduct - query := wmi.CreateQuery(&dst, "") - err := wmi.Query(query, &dst) - if err != nil { - return "", err - } - // `ComputerSystemProduct` could be empty on some virtualized systems - if len(dst) < 1 { - return "unknown", nil - } - return dst[0].Name, nil -} - -func sysManufacturer() (string, error) { - var dst []Win32_ComputerSystem - query := wmi.CreateQuery(&dst, "") - err := wmi.Query(query, &dst) - if err != nil { - return "", err - } - return dst[0].Manufacturer, nil -} diff --git a/client/system/static_info.go b/client/system/static_info.go index f178ec932..12a2663a1 100644 --- a/client/system/static_info.go +++ b/client/system/static_info.go @@ -3,12 +3,7 @@ package system import ( - "context" "sync" - "time" - - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" ) var ( @@ -16,25 +11,26 @@ var ( once sync.Once ) -func updateStaticInfo() StaticInfo { +// StaticInfo is an object that contains machine information that does not change +type StaticInfo struct { + SystemSerialNumber string + SystemProductName string + SystemManufacturer string + Environment Environment + + // Windows specific fields + OSName string + OSVersion string + BuildVersion string +} + +func updateStaticInfo() { once.Do(func() { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - wg := sync.WaitGroup{} - wg.Add(3) - go func() { - staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo() - wg.Done() - }() - go func() { - staticInfo.Environment.Cloud = detect_cloud.Detect(ctx) - wg.Done() - }() - go func() { - staticInfo.Environment.Platform = detect_platform.Detect(ctx) - wg.Done() - }() - wg.Wait() + staticInfo = newStaticInfo() }) +} + +func getStaticInfo() StaticInfo { + updateStaticInfo() return 
staticInfo } diff --git a/client/system/static_info_stub.go b/client/system/static_info_stub.go deleted file mode 100644 index faa3e700b..000000000 --- a/client/system/static_info_stub.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build android || freebsd || ios - -package system - -// updateStaticInfo returns an empty implementation for unsupported platforms -func updateStaticInfo() StaticInfo { - return StaticInfo{} -} diff --git a/client/system/static_info_update.go b/client/system/static_info_update.go new file mode 100644 index 000000000..af8b1e266 --- /dev/null +++ b/client/system/static_info_update.go @@ -0,0 +1,35 @@ +//go:build (linux && !android) || (darwin && !ios) + +package system + +import ( + "context" + "sync" + "time" + + "github.com/netbirdio/netbird/client/system/detect_cloud" + "github.com/netbirdio/netbird/client/system/detect_platform" +) + +func newStaticInfo() StaticInfo { + si := StaticInfo{} + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + wg := sync.WaitGroup{} + wg.Add(3) + go func() { + si.SystemSerialNumber, si.SystemProductName, si.SystemManufacturer = sysInfo() + wg.Done() + }() + go func() { + si.Environment.Cloud = detect_cloud.Detect(ctx) + wg.Done() + }() + go func() { + si.Environment.Platform = detect_platform.Detect(ctx) + wg.Done() + }() + wg.Wait() + return si +} diff --git a/client/system/static_info_update_windows.go b/client/system/static_info_update_windows.go new file mode 100644 index 000000000..5f232c1de --- /dev/null +++ b/client/system/static_info_update_windows.go @@ -0,0 +1,184 @@ +package system + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "github.com/yusufpapurcu/wmi" + "golang.org/x/sys/windows/registry" + + "github.com/netbirdio/netbird/client/system/detect_cloud" + "github.com/netbirdio/netbird/client/system/detect_platform" +) + +type Win32_OperatingSystem struct { + Caption string +} + +type Win32_ComputerSystem 
struct { + Manufacturer string +} + +type Win32_ComputerSystemProduct struct { + Name string +} + +type Win32_BIOS struct { + SerialNumber string +} + +func newStaticInfo() StaticInfo { + si := StaticInfo{} + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + si.SystemSerialNumber, si.SystemProductName, si.SystemManufacturer = sysInfo() + wg.Done() + }() + wg.Add(1) + go func() { + si.Environment.Cloud = detect_cloud.Detect(ctx) + wg.Done() + }() + wg.Add(1) + go func() { + si.Environment.Platform = detect_platform.Detect(ctx) + wg.Done() + }() + wg.Add(1) + go func() { + si.OSName, si.OSVersion = getOSNameAndVersion() + wg.Done() + }() + wg.Add(1) + go func() { + si.BuildVersion = getBuildVersion() + wg.Done() + }() + wg.Wait() + return si +} + +func sysInfo() (serialNumber string, productName string, manufacturer string) { + var err error + serialNumber, err = sysNumber() + if err != nil { + log.Warnf("failed to get system serial number: %s", err) + } + + productName, err = sysProductName() + if err != nil { + log.Warnf("failed to get system product name: %s", err) + } + + manufacturer, err = sysManufacturer() + if err != nil { + log.Warnf("failed to get system manufacturer: %s", err) + } + + return serialNumber, productName, manufacturer +} + +func sysNumber() (string, error) { + var dst []Win32_BIOS + query := wmi.CreateQuery(&dst, "") + err := wmi.Query(query, &dst) + if err != nil { + return "", err + } + return dst[0].SerialNumber, nil +} + +func sysProductName() (string, error) { + var dst []Win32_ComputerSystemProduct + query := wmi.CreateQuery(&dst, "") + err := wmi.Query(query, &dst) + if err != nil { + return "", err + } + // `ComputerSystemProduct` could be empty on some virtualized systems + if len(dst) < 1 { + return "unknown", nil + } + return dst[0].Name, nil +} + +func sysManufacturer() (string, error) { + var dst []Win32_ComputerSystem + query := 
wmi.CreateQuery(&dst, "") + err := wmi.Query(query, &dst) + if err != nil { + return "", err + } + return dst[0].Manufacturer, nil +} + +func getOSNameAndVersion() (string, string) { + var dst []Win32_OperatingSystem + query := wmi.CreateQuery(&dst, "") + err := wmi.Query(query, &dst) + if err != nil { + log.Error(err) + return "Windows", getBuildVersion() + } + + if len(dst) == 0 { + return "Windows", getBuildVersion() + } + + split := strings.Split(dst[0].Caption, " ") + + if len(split) <= 3 { + return "Windows", getBuildVersion() + } + + name := split[1] + version := split[2] + if split[2] == "Server" { + name = fmt.Sprintf("%s %s", split[1], split[2]) + version = split[3] + } + + return name, version +} + +func getBuildVersion() string { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE) + if err != nil { + log.Error(err) + return "0.0.0.0" + } + defer func() { + deferErr := k.Close() + if deferErr != nil { + log.Error(deferErr) + } + }() + + major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber") + if err != nil { + log.Error(err) + } + minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber") + if err != nil { + log.Error(err) + } + build, _, err := k.GetStringValue("CurrentBuildNumber") + if err != nil { + log.Error(err) + } + // Update Build Revision + ubr, _, err := k.GetIntegerValue("UBR") + if err != nil { + log.Error(err) + } + ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr) + return ver +} From 5113c7094317a99aaf247720db21741c3327d088 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Sat, 6 Sep 2025 13:13:49 +0300 Subject: [PATCH 21/41] [management] Extends integration and peers manager (#4450) --- client/cmd/testutil_test.go | 3 +- client/internal/engine_test.go | 3 +- client/server/server_test.go | 3 +- go.mod | 2 +- go.sum | 4 +- management/internals/server/controllers.go | 6 +- management/server/peers/manager.go | 5 + 
management/server/peers/manager_mock.go | 15 +++ management/server/store/sql_store.go | 19 ++++ management/server/store/sql_store_test.go | 110 +++++++++++++++++++++ management/server/store/store.go | 1 + shared/management/client/client_test.go | 4 +- 12 files changed, 167 insertions(+), 8 deletions(-) diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 3f08d12ba..99ccb1539 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -93,8 +93,9 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp permissionsManagerMock := permissions.NewMockManager(ctrl) peersmanager := peers.NewManager(store, permissionsManagerMock) + settingsManagerMock := settings.NewMockManager(ctrl) - iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, eventStore) + iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index afe4622b3..90c8cbc60 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1557,7 +1557,8 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri permissionsManager := permissions.NewManager(store) peersManager := peers.NewManager(store, permissionsManager) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, eventStore) + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) diff --git a/client/server/server_test.go b/client/server/server_test.go index 493c8601a..87889cbce 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -299,8 +299,9 @@ func startManagement(t *testing.T, signalAddr string, 
counter *int) (*grpc.Serve permissionsManagerMock := permissions.NewMockManager(ctrl) peersManager := peers.NewManager(store, permissionsManagerMock) + settingsManagerMock := settings.NewMockManager(ctrl) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, eventStore) + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) diff --git a/go.mod b/go.mod index 68730bf53..70e52875f 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250826184705-1866b8dd841f + github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 4783c47eb..3fdef5d08 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 h1:ZJwhKexMlK15B/Ld+1T8VYE2Mt1lk1kf2DlXr46EHcw= github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250826184705-1866b8dd841f h1:r1gnjw0TfkaDLSCmAE3g5N5ulcd5WpFHaGrqQomCXP4= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250826184705-1866b8dd841f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations 
v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index f9023b204..984a56a39 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -20,7 +20,11 @@ func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager { func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator { return Create(s, func() integrated_validator.IntegratedValidator { - integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.PeersManager(), s.EventStore()) + integratedPeerValidator, err := integrations.NewIntegratedValidator( + context.Background(), + s.PeersManager(), + s.SettingsManager(), + s.EventStore()) if err != nil { log.Errorf("failed to create integrated peer validator: %v", err) } diff --git a/management/server/peers/manager.go b/management/server/peers/manager.go index 50e36a880..cb135f4ac 100644 --- a/management/server/peers/manager.go +++ b/management/server/peers/manager.go @@ -18,6 +18,7 @@ type Manager interface { GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) GetPeerAccountID(ctx context.Context, peerID string) (string, error) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) + GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) 
([]*peer.Peer, error) } type managerImpl struct { @@ -61,3 +62,7 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) } + +func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { + return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs) +} diff --git a/management/server/peers/manager_mock.go b/management/server/peers/manager_mock.go index b247a1752..994f8346b 100644 --- a/management/server/peers/manager_mock.go +++ b/management/server/peers/manager_mock.go @@ -79,3 +79,18 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID) } + +// GetPeersByGroupIDs mocks base method. +func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPeersByGroupIDs", ctx, accountID, groupsIDs) + ret0, _ := ret[0].([]*peer.Peer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPeersByGroupIDs indicates an expected call of GetPeersByGroupIDs. 
+func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs) +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 45561f950..027938320 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2847,3 +2847,22 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i } return nil } + +func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) { + if len(groupIDs) == 0 { + return []*nbpeer.Peer{}, nil + } + + var peers []*nbpeer.Peer + peerIDsSubquery := s.db.Model(&types.GroupPeer{}). + Select("DISTINCT peer_id"). + Where("account_id = ? AND group_id IN ?", accountID, groupIDs) + + result := s.db.Where("id IN (?)", peerIDsSubquery).Find(&peers) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get peers by group IDs: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers by group IDs") + } + + return peers, nil +} diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 935b0a595..d40c4664c 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -3607,3 +3607,113 @@ func intToIPv4(n uint32) net.IP { binary.BigEndian.PutUint32(ip, n) return ip } + +func TestSqlStore_GetPeersByGroupIDs(t *testing.T) { + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + group1ID := "test-group-1" + group2ID := "test-group-2" + emptyGroupID := "empty-group" + + peer1 := "cfefqs706sqkneg59g4g" + peer2 := "cfeg6sf06sqkneg59g50" + + tests := []struct { + name string + groupIDs []string + expectedPeers []string + expectedCount int + }{ + { + 
name: "retrieve peers from single group with multiple peers", + groupIDs: []string{group1ID}, + expectedPeers: []string{peer1, peer2}, + expectedCount: 2, + }, + { + name: "retrieve peers from single group with one peer", + groupIDs: []string{group2ID}, + expectedPeers: []string{peer1}, + expectedCount: 1, + }, + { + name: "retrieve peers from multiple groups (with overlap)", + groupIDs: []string{group1ID, group2ID}, + expectedPeers: []string{peer1, peer2}, // should deduplicate + expectedCount: 2, + }, + { + name: "retrieve peers from existing 'All' group", + groupIDs: []string{"cfefqs706sqkneg59g3g"}, // All group from test data + expectedPeers: []string{peer1, peer2}, + expectedCount: 2, + }, + { + name: "retrieve peers from empty group", + groupIDs: []string{emptyGroupID}, + expectedPeers: []string{}, + expectedCount: 0, + }, + { + name: "retrieve peers from non-existing group", + groupIDs: []string{"non-existing-group"}, + expectedPeers: []string{}, + expectedCount: 0, + }, + { + name: "empty group IDs list", + groupIDs: []string{}, + expectedPeers: []string{}, + expectedCount: 0, + }, + { + name: "mix of existing and non-existing groups", + groupIDs: []string{group1ID, "non-existing-group"}, + expectedPeers: []string{peer1, peer2}, + expectedCount: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + ctx := context.Background() + + groups := []*types.Group{ + { + ID: group1ID, + AccountID: accountID, + }, + { + ID: group2ID, + AccountID: accountID, + }, + } + require.NoError(t, store.CreateGroups(ctx, accountID, groups)) + + require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group1ID)) + require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer2, group1ID)) + require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group2ID)) + + peers, err := 
store.GetPeersByGroupIDs(ctx, accountID, tt.groupIDs) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + + if tt.expectedCount > 0 { + actualPeerIDs := make([]string, len(peers)) + for i, peer := range peers { + actualPeerIDs[i] = peer.ID + } + assert.ElementsMatch(t, tt.expectedPeers, actualPeerIDs) + + // Verify all returned peers belong to the correct account + for _, peer := range peers { + assert.Equal(t, accountID, peer.AccountID) + } + } + }) + } +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 545549410..3c9d896b0 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -136,6 +136,7 @@ type Store interface { GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) + GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index b04cdd96a..becc10ded 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -86,7 +86,9 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { AnyTimes() peersManger := peers.NewManager(store, permissionsManagerMock) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, eventStore) 
+ settingsManagerMock := settings.NewMockManager(ctrl) + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, settingsManagerMock, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) From 69d87343d2bbde0d30deaed7b5f2970827e93c5e Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 8 Sep 2025 14:51:34 +0200 Subject: [PATCH 22/41] [client] Debug information for connection (#4439) Improve logging Print the exact time when the first WireGuard handshake occurs Print the steps for gathering system information --- client/iface/configurer/usp.go | 7 +++++++ client/internal/engine.go | 3 +-- client/internal/peer/wg_watcher.go | 13 ++++++++++--- client/system/info.go | 4 ++++ client/system/info_windows.go | 1 - 5 files changed, 22 insertions(+), 6 deletions(-) diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index 171458e38..945f1a162 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -394,6 +394,13 @@ func toLastHandshake(stringVar string) (time.Time, error) { if err != nil { return time.Time{}, fmt.Errorf("parse handshake sec: %w", err) } + + // If sec is 0 (Unix epoch), return zero time instead + // This indicates no handshake has occurred + if sec == 0 { + return time.Time{}, nil + } + return time.Unix(sec, 0), nil } diff --git a/client/internal/engine.go b/client/internal/engine.go index ca01bfd14..61ff41600 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -949,7 +949,6 @@ func (e *Engine) receiveManagementEvents() { e.config.LazyConnectionEnabled, ) - // err = e.mgmClient.Sync(info, e.handleSync) err = e.mgmClient.Sync(e.ctx, info, e.handleSync) if err != nil { // happens if management is unavailable for a long time. 
@@ -960,7 +959,7 @@ func (e *Engine) receiveManagementEvents() { } log.Debugf("stopped receiving updates from Management Service") }() - log.Debugf("connecting to Management Service updates stream") + log.Infof("connecting to Management Service updates stream") } func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error { diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go index 218872c15..0ed200fda 100644 --- a/client/internal/peer/wg_watcher.go +++ b/client/internal/peer/wg_watcher.go @@ -30,9 +30,10 @@ type WGWatcher struct { peerKey string stateDump *stateDump - ctx context.Context - ctxCancel context.CancelFunc - ctxLock sync.Mutex + ctx context.Context + ctxCancel context.CancelFunc + ctxLock sync.Mutex + enabledTime time.Time } func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher { @@ -48,6 +49,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) { w.log.Debugf("enable WireGuard watcher") w.ctxLock.Lock() + w.enabledTime = time.Now() if w.ctx != nil && w.ctx.Err() == nil { w.log.Errorf("WireGuard watcher already enabled") @@ -101,6 +103,11 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex onDisconnectedFn() return } + if lastHandshake.IsZero() { + elapsed := handshake.Sub(w.enabledTime).Seconds() + w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake) + } + lastHandshake = *handshake resetTime := time.Until(handshake.Add(checkPeriod)) diff --git a/client/system/info.go b/client/system/info.go index ceb1682f3..a180be4c0 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -6,6 +6,7 @@ import ( "net/netip" "strings" + log "github.com/sirupsen/logrus" "google.golang.org/grpc/metadata" "github.com/netbirdio/netbird/shared/management/proto" @@ -172,6 +173,7 @@ func 
isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool { // GetInfoWithChecks retrieves and parses the system information with applied checks. func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) { + log.Debugf("gathering system information with checks: %d", len(checks)) processCheckPaths := make([]string, 0) for _, check := range checks { processCheckPaths = append(processCheckPaths, check.GetFiles()...) @@ -181,9 +183,11 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro if err != nil { return nil, err } + log.Debugf("gathering process check information completed") info := GetInfo(ctx) info.Files = files + log.Debugf("all system information gathered successfully") return info, nil } diff --git a/client/system/info_windows.go b/client/system/info_windows.go index e67356f57..d7f8f30aa 100644 --- a/client/system/info_windows.go +++ b/client/system/info_windows.go @@ -48,6 +48,5 @@ func GetInfo(ctx context.Context) *Info { gio.Hostname = extractDeviceName(ctx, systemHostname) gio.NetbirdVersion = version.NetbirdVersion() gio.UIVersion = extractUserAgent(ctx) - return gio } From dba7ef667d6921e4d7b0381df28a3b5ae8d3db16 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 8 Sep 2025 10:03:56 -0300 Subject: [PATCH 23/41] [misc] Remove aur support and start service on ostree (#4461) * Remove aur support and start service on ostree The aur installation was adding many packages and installing more than just the client. 
For now is best to remove it and rely on binary install Some users complained about ostree installation not starting the client, we add two explicit commands to it * use ${SUDO} * fix if closure --- release_files/install.sh | 43 +++------------------------------------- 1 file changed, 3 insertions(+), 40 deletions(-) diff --git a/release_files/install.sh b/release_files/install.sh index 856d332cb..5d5349ec4 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -130,36 +130,6 @@ repo_gpgcheck=1 EOF } -install_aur_package() { - INSTALL_PKGS="git base-devel go" - REMOVE_PKGS="" - - # Check if dependencies are installed - for PKG in $INSTALL_PKGS; do - if ! pacman -Q "$PKG" > /dev/null 2>&1; then - # Install missing package(s) - ${SUDO} pacman -S "$PKG" --noconfirm - - # Add installed package for clean up later - REMOVE_PKGS="$REMOVE_PKGS $PKG" - fi - done - - # Build package from AUR - cd /tmp && git clone https://aur.archlinux.org/netbird.git - cd netbird && makepkg -sri --noconfirm - - if ! $SKIP_UI_APP; then - cd /tmp && git clone https://aur.archlinux.org/netbird-ui.git - cd netbird-ui && makepkg -sri --noconfirm - fi - - if [ -n "$REMOVE_PKGS" ]; then - # Clean up the installed packages - ${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm - fi -} - prepare_tun_module() { # Create the necessary file structure for /dev/net/tun if [ ! -c /dev/net/tun ]; then @@ -276,12 +246,9 @@ install_netbird() { if ! 
$SKIP_UI_APP; then ${SUDO} rpm-ostree -y install netbird-ui fi - ;; - pacman) - ${SUDO} pacman -Syy - install_aur_package - # in-line with the docs at https://wiki.archlinux.org/title/Netbird - ${SUDO} systemctl enable --now netbird@main.service + # ensure the service is started after install + ${SUDO} netbird service install || true + ${SUDO} netbird service start || true ;; pkg) # Check if the package is already installed @@ -458,11 +425,7 @@ if type uname >/dev/null 2>&1; then elif [ -x "$(command -v yum)" ]; then PACKAGE_MANAGER="yum" echo "The installation will be performed using yum package manager" - elif [ -x "$(command -v pacman)" ]; then - PACKAGE_MANAGER="pacman" - echo "The installation will be performed using pacman package manager" fi - else echo "Unable to determine OS type from /etc/os-release" exit 1 From 7aef0f67df4b5974fc08b2808b8de3998e8a9de9 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 8 Sep 2025 18:42:42 +0200 Subject: [PATCH 24/41] [client] Implement environment variable handling for Android (#4440) Some features can only be manipulated via environment variables. With this PR, environment variables can be managed from Android. --- client/android/client.go | 18 ++++++++++++++++-- client/android/env_list.go | 32 ++++++++++++++++++++++++++++++++ client/internal/peer/conn.go | 3 +-- client/internal/peer/env.go | 14 ++++++++++++++ 4 files changed, 63 insertions(+), 4 deletions(-) create mode 100644 client/android/env_list.go create mode 100644 client/internal/peer/env.go diff --git a/client/android/client.go b/client/android/client.go index c05246569..4b4fcc9be 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -4,6 +4,7 @@ package android import ( "context" + "os" "slices" "sync" @@ -83,7 +84,8 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi } // Run start the internal client. 
It is a blocker function -func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error { +func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { + exportEnvList(envList) cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) @@ -118,7 +120,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // In this case make no sense handle registration steps. -func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error { +func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { + exportEnvList(envList) cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) @@ -249,3 +252,14 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) { func (c *Client) RemoveConnectionListener() { c.recorder.RemoveConnectionListener() } + +func exportEnvList(list *EnvList) { + if list == nil { + return + } + for k, v := range list.AllItems() { + if err := os.Setenv(k, v); err != nil { + log.Errorf("could not set env variable %s: %v", k, err) + } + } +} diff --git a/client/android/env_list.go b/client/android/env_list.go new file mode 100644 index 000000000..04122300a --- /dev/null +++ b/client/android/env_list.go @@ -0,0 +1,32 @@ +package android + +import "github.com/netbirdio/netbird/client/internal/peer" + +var ( + // EnvKeyNBForceRelay Exported for Android java client + EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay +) + +// EnvList wraps a Go map for export to Java +type EnvList struct { + data map[string]string +} + +// NewEnvList creates a new EnvList +func NewEnvList() *EnvList { + return &EnvList{data: make(map[string]string)} +} + +// Put adds a key-value pair 
+func (el *EnvList) Put(key, value string) { + el.data[key] = value +} + +// Get retrieves a value by key +func (el *EnvList) Get(key string) string { + return el.data[key] +} + +func (el *EnvList) AllItems() map[string]string { + return el.data +} diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 224a8144c..86e4596d4 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -6,7 +6,6 @@ import ( "math/rand" "net" "net/netip" - "os" "runtime" "sync" "time" @@ -174,7 +173,7 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) - if os.Getenv("NB_FORCE_RELAY") != "true" { + if !isForceRelayed() { conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) } diff --git a/client/internal/peer/env.go b/client/internal/peer/env.go new file mode 100644 index 000000000..32a458d00 --- /dev/null +++ b/client/internal/peer/env.go @@ -0,0 +1,14 @@ +package peer + +import ( + "os" + "strings" +) + +const ( + EnvKeyNBForceRelay = "NB_FORCE_RELAY" +) + +func isForceRelayed() bool { + return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true") +} From 9e81e782e50b734f68b107ccb0e2ca37c74d6711 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 11 Sep 2025 10:08:54 +0200 Subject: [PATCH 25/41] [client] Fix/v4 stun routing (#4430) Deduplicate STUN package sending. Originally, because every peer shared the same UDP address, the library could not distinguish which STUN message was associated with which candidate. As a result, the Pion library responded from all candidates for every STUN message. 
--- client/iface/bind/ice_bind.go | 13 +- client/iface/bind/udp_mux_ios.go | 7 - client/iface/device.go | 4 +- client/iface/device/device_android.go | 5 +- client/iface/device/device_darwin.go | 5 +- client/iface/device/device_ios.go | 5 +- client/iface/device/device_kernel_unix.go | 12 +- client/iface/device/device_netstack.go | 5 +- client/iface/device/device_usp_unix.go | 5 +- client/iface/device/device_windows.go | 5 +- client/iface/device_android.go | 4 +- client/iface/iface.go | 6 +- .../udp_muxed_conn.go => udpmux/conn.go} | 17 +- client/iface/udpmux/doc.go | 64 ++++++ .../iface/{bind/udp_mux.go => udpmux/mux.go} | 193 +++++++++++------- .../mux_generic.go} | 4 +- client/iface/udpmux/mux_ios.go | 7 + .../universal.go} | 14 +- client/internal/engine.go | 8 +- client/internal/engine_test.go | 9 +- client/internal/iface_common.go | 4 +- client/internal/peer/worker_ice.go | 78 +++---- go.mod | 2 +- go.sum | 4 +- 24 files changed, 301 insertions(+), 179 deletions(-) delete mode 100644 client/iface/bind/udp_mux_ios.go rename client/iface/{bind/udp_muxed_conn.go => udpmux/conn.go} (95%) create mode 100644 client/iface/udpmux/doc.go rename client/iface/{bind/udp_mux.go => udpmux/mux.go} (65%) rename client/iface/{bind/udp_mux_generic.go => udpmux/mux_generic.go} (85%) create mode 100644 client/iface/udpmux/mux_ios.go rename client/iface/{bind/udp_mux_universal.go => udpmux/universal.go} (97%) diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index 359d2129b..b74f90d6c 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -15,6 +15,7 @@ import ( "golang.org/x/net/ipv6" wgConn "golang.zx2c4.com/wireguard/conn" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -44,7 +45,7 @@ type ICEBind struct { RecvChan chan RecvMessage transportNet transport.Net - filterFn FilterFn + filterFn udpmux.FilterFn endpoints 
map[netip.Addr]net.Conn endpointsMu sync.Mutex // every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a @@ -54,13 +55,13 @@ type ICEBind struct { closed bool muUDPMux sync.Mutex - udpMux *UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault address wgaddr.Address mtu uint16 activityRecorder *ActivityRecorder } -func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { +func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) ib := &ICEBind{ StdNetBind: b, @@ -115,7 +116,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder { } // GetICEMux returns the ICE UDPMux that was created and used by ICEBind -func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { +func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() if s.udpMux == nil { @@ -158,8 +159,8 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r s.muUDPMux.Lock() defer s.muUDPMux.Unlock() - s.udpMux = NewUniversalUDPMuxDefault( - UniversalUDPMuxParams{ + s.udpMux = udpmux.NewUniversalUDPMuxDefault( + udpmux.UniversalUDPMuxParams{ UDPConn: nbnet.WrapPacketConn(conn), Net: s.transportNet, FilterFn: s.filterFn, diff --git a/client/iface/bind/udp_mux_ios.go b/client/iface/bind/udp_mux_ios.go deleted file mode 100644 index db0249d11..000000000 --- a/client/iface/bind/udp_mux_ios.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build ios - -package bind - -func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { - // iOS doesn't support nbnet hooks, so this is a no-op -} diff --git a/client/iface/device.go b/client/iface/device.go index ca6dda2c2..921f0ea98 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -7,14 +7,14 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" 
- "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) type WGTunDevice interface { Create() (device.WGConfigurer, error) - Up() (*bind.UniversalUDPMuxDefault, error) + Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(address wgaddr.Address) error WgAddress() wgaddr.Address MTU() uint16 diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index fe3b9f82e..a731684cc 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -29,7 +30,7 @@ type WGTunDevice struct { name string device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -88,7 +89,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string } return t.configurer, nil } -func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index cce9d42df..390efe088 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -26,7 +27,7 @@ type TunDevice struct { device *device.Device filteredDevice *FilteredDevice - udpMux 
*bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -71,7 +72,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index 168985b5e..96e4c8bcf 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -28,7 +29,7 @@ type TunDevice struct { device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -83,7 +84,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index 00a72bcc6..2ef6f6b22 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -12,8 +12,8 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/sharedsock" nbnet "github.com/netbirdio/netbird/util/net" @@ -31,9 +31,9 @@ type TunKernelDevice struct { link *wgLink 
udpMuxConn net.PacketConn - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault - filterFn bind.FilterFn + filterFn udpmux.FilterFn } func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice { @@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) { return configurer, nil } -func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.udpMux != nil { return t.udpMux, nil } @@ -106,14 +106,14 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { udpConn = nbnet.WrapPacketConn(rawSock) } - bindParams := bind.UniversalUDPMuxParams{ + bindParams := udpmux.UniversalUDPMuxParams{ UDPConn: udpConn, Net: t.transportNet, FilterFn: t.filterFn, WGAddress: t.address, MTU: t.mtu, } - mux := bind.NewUniversalUDPMuxDefault(bindParams) + mux := udpmux.NewUniversalUDPMuxDefault(bindParams) go mux.ReadFromConn(t.ctx) t.udpMuxConn = rawSock t.udpMux = mux diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index f41331ff7..2fcc74809 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -10,6 +10,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -26,7 +27,7 @@ type TunNetstackDevice struct { device *device.Device filteredDevice *FilteredDevice nsTun *nbnetstack.NetStackTun - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer net *netstack.Net @@ -80,7 +81,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) { return 
t.configurer, nil } -func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 8d30112ae..4cdd70a32 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -25,7 +26,7 @@ type USPDevice struct { device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -74,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index de258868f..f1023bc0a 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -29,7 +30,7 @@ type TunDevice struct { device *device.Device nativeTunDevice *tun.NativeTun filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -104,7 +105,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t 
*TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err diff --git a/client/iface/device_android.go b/client/iface/device_android.go index 39b5c28ae..4649b8b97 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -5,14 +5,14 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) type WGTunDevice interface { Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) - Up() (*bind.UniversalUDPMuxDefault, error) + Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(address wgaddr.Address) error WgAddress() wgaddr.Address MTU() uint16 diff --git a/client/iface/iface.go b/client/iface/iface.go index 9a42223a1..609572561 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -16,9 +16,9 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/monotime" @@ -61,7 +61,7 @@ type WGIFaceOpts struct { MTU uint16 MobileArgs *device.MobileIFaceArguments TransportNet transport.Net - FilterFn bind.FilterFn + FilterFn udpmux.FilterFn DisableDNS bool } @@ -114,7 +114,7 @@ func (r *WGIface) ToInterface() *net.Interface { // Up configures a Wireguard interface // The interface must exist before calling this method (e.g. 
call interface.Create() before) -func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) { +func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { w.mu.Lock() defer w.mu.Unlock() diff --git a/client/iface/bind/udp_muxed_conn.go b/client/iface/udpmux/conn.go similarity index 95% rename from client/iface/bind/udp_muxed_conn.go rename to client/iface/udpmux/conn.go index 7cacf1c31..3aa40caeb 100644 --- a/client/iface/bind/udp_muxed_conn.go +++ b/client/iface/udpmux/conn.go @@ -1,4 +1,4 @@ -package bind +package udpmux /* Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements @@ -16,11 +16,12 @@ import ( ) type udpMuxedConnParams struct { - Mux *UDPMuxDefault - AddrPool *sync.Pool - Key string - LocalAddr net.Addr - Logger logging.LeveledLogger + Mux *SingleSocketUDPMux + AddrPool *sync.Pool + Key string + LocalAddr net.Addr + Logger logging.LeveledLogger + CandidateID string } // udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag @@ -119,6 +120,10 @@ func (c *udpMuxedConn) Close() error { return err } +func (c *udpMuxedConn) GetCandidateID() string { + return c.params.CandidateID +} + func (c *udpMuxedConn) isClosed() bool { select { case <-c.closedChan: diff --git a/client/iface/udpmux/doc.go b/client/iface/udpmux/doc.go new file mode 100644 index 000000000..27e5e43bc --- /dev/null +++ b/client/iface/udpmux/doc.go @@ -0,0 +1,64 @@ +// Package udpmux provides a custom implementation of a UDP multiplexer +// that allows multiple logical ICE connections to share a single underlying +// UDP socket. This is based on Pion's ICE library, with modifications for +// NetBird's requirements. +// +// # Background +// +// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity +// Establishment) is responsible for discovering candidate network paths +// and maintaining connectivity between peers. Each ICE connection +// normally requires a dedicated UDP socket. 
However, using one socket +// per candidate can be inefficient and difficult to manage. +// +// This package introduces SingleSocketUDPMux, which allows multiple ICE +// candidate connections (muxed connections) to share a single UDP socket. +// It handles demultiplexing of packets based on ICE ufrag values, STUN +// attributes, and candidate IDs. +// +// # Usage +// +// The typical flow is: +// +// 1. Create a UDP socket (net.PacketConn). +// 2. Construct Params with the socket and optional logger/net stack. +// 3. Call NewSingleSocketUDPMux(params). +// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID) +// to obtain a logical PacketConn. +// 5. Use the returned PacketConn just like a normal UDP connection. +// +// # STUN Message Routing Logic +// +// When a STUN packet arrives, the mux decides which connection should +// receive it using this routing logic: +// +// Primary Routing: Candidate Pair ID +// - Extract the candidate pair ID from the STUN message using +// ice.CandidatePairIDFromSTUN(msg) +// - The target candidate is the locally generated candidate that +// corresponds to the connection that should handle this STUN message +// - If found, use the target candidate ID to lookup the specific +// connection in candidateConnMap +// - Route the message directly to that connection +// +// Fallback Routing: Broadcasting +// When candidate pair ID is not available or lookup fails: +// - Collect connections from addressMap based on source address +// - Find connection using username attribute (ufrag) from STUN message +// - Remove duplicate connections from the list +// - Send the STUN message to all collected connections +// +// # Peer Reflexive Candidate Discovery +// +// When a remote peer sends a STUN message from an unknown source address +// (from a candidate that has not been exchanged via signal), the ICE +// library will: +// - Generate a new peer reflexive candidate for this source address +// - Extract or assign a candidate ID 
based on the STUN message attributes +// - Create a mapping between the new peer reflexive candidate ID and +// the appropriate local connection +// +// This discovery mechanism ensures that STUN messages from newly discovered +// peer reflexive candidates can be properly routed to the correct local +// connection without requiring fallback broadcasting. +package udpmux diff --git a/client/iface/bind/udp_mux.go b/client/iface/udpmux/mux.go similarity index 65% rename from client/iface/bind/udp_mux.go rename to client/iface/udpmux/mux.go index db7494405..319724926 100644 --- a/client/iface/bind/udp_mux.go +++ b/client/iface/udpmux/mux.go @@ -1,4 +1,4 @@ -package bind +package udpmux import ( "fmt" @@ -22,9 +22,9 @@ import ( const receiveMTU = 8192 -// UDPMuxDefault is an implementation of the interface -type UDPMuxDefault struct { - params UDPMuxParams +// SingleSocketUDPMux is an implementation of the interface +type SingleSocketUDPMux struct { + params Params closedChan chan struct{} closeOnce sync.Once @@ -32,6 +32,9 @@ type UDPMuxDefault struct { // connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType connsIPv4, connsIPv6 map[string]*udpMuxedConn + // candidateConnMap maps local candidate IDs to their corresponding connection. + candidateConnMap map[string]*udpMuxedConn + addressMapMu sync.RWMutex addressMap map[string][]*udpMuxedConn @@ -46,8 +49,8 @@ type UDPMuxDefault struct { const maxAddrSize = 512 -// UDPMuxParams are parameters for UDPMux. -type UDPMuxParams struct { +// Params are parameters for UDPMux. 
+type Params struct { Logger logging.LeveledLogger UDPConn net.PacketConn @@ -147,18 +150,19 @@ func isZeros(ip net.IP) bool { return true } -// NewUDPMuxDefault creates an implementation of UDPMux -func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { +// NewSingleSocketUDPMux creates an implementation of UDPMux +func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux { if params.Logger == nil { params.Logger = getLogger() } - mux := &UDPMuxDefault{ - addressMap: map[string][]*udpMuxedConn{}, - params: params, - connsIPv4: make(map[string]*udpMuxedConn), - connsIPv6: make(map[string]*udpMuxedConn), - closedChan: make(chan struct{}, 1), + mux := &SingleSocketUDPMux{ + addressMap: map[string][]*udpMuxedConn{}, + params: params, + connsIPv4: make(map[string]*udpMuxedConn), + connsIPv6: make(map[string]*udpMuxedConn), + candidateConnMap: make(map[string]*udpMuxedConn), + closedChan: make(chan struct{}, 1), pool: &sync.Pool{ New: func() interface{} { // big enough buffer to fit both packet and address @@ -171,15 +175,15 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { return mux } -func (m *UDPMuxDefault) updateLocalAddresses() { +func (m *SingleSocketUDPMux) updateLocalAddresses() { var localAddrsForUnspecified []net.Addr if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr()) } else if ok && addr.IP.IsUnspecified() { // For unspecified addresses, the correct behavior is to return errListenUnspecified, but // it will break the applications that are already using unspecified UDP connection - // with UDPMuxDefault, so print a warn log and create a local address list for mux. - m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") + // with SingleSocketUDPMux, so print a warn log and create a local address list for mux. 
+ m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") var networks []ice.NetworkType switch { @@ -216,13 +220,13 @@ func (m *UDPMuxDefault) updateLocalAddresses() { m.mu.Unlock() } -// LocalAddr returns the listening address of this UDPMuxDefault -func (m *UDPMuxDefault) LocalAddr() net.Addr { +// LocalAddr returns the listening address of this SingleSocketUDPMux +func (m *SingleSocketUDPMux) LocalAddr() net.Addr { return m.params.UDPConn.LocalAddr() } // GetListenAddresses returns the list of addresses that this mux is listening on -func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { +func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr { m.updateLocalAddresses() m.mu.Lock() @@ -236,7 +240,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { // GetConn returns a PacketConn given the connection's ufrag and network address // creates the connection if an existing one can't be found -func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { +func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) { // don't check addr for mux using unspecified address m.mu.Lock() lenLocalAddrs := len(m.localAddrsForUnspecified) @@ -260,12 +264,14 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er return conn, nil } - c := m.createMuxedConn(ufrag) + c := m.createMuxedConn(ufrag, candidateID) go func() { <-c.CloseChannel() m.RemoveConnByUfrag(ufrag) }() + m.candidateConnMap[candidateID] = c + if isIPv6 { m.connsIPv6[ufrag] = c } else { @@ -276,7 +282,7 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er } // RemoveConnByUfrag stops and removes the muxed packet connection -func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { +func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) { removedConns := make([]*udpMuxedConn, 0, 2) // Keep lock 
section small to avoid deadlock with conn lock @@ -284,10 +290,12 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { if c, ok := m.connsIPv4[ufrag]; ok { delete(m.connsIPv4, ufrag) removedConns = append(removedConns, c) + delete(m.candidateConnMap, c.GetCandidateID()) } if c, ok := m.connsIPv6[ufrag]; ok { delete(m.connsIPv6, ufrag) removedConns = append(removedConns, c) + delete(m.candidateConnMap, c.GetCandidateID()) } m.mu.Unlock() @@ -314,7 +322,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { } // IsClosed returns true if the mux had been closed -func (m *UDPMuxDefault) IsClosed() bool { +func (m *SingleSocketUDPMux) IsClosed() bool { select { case <-m.closedChan: return true @@ -324,7 +332,7 @@ func (m *UDPMuxDefault) IsClosed() bool { } // Close the mux, no further connections could be created -func (m *UDPMuxDefault) Close() error { +func (m *SingleSocketUDPMux) Close() error { var err error m.closeOnce.Do(func() { m.mu.Lock() @@ -347,11 +355,11 @@ func (m *UDPMuxDefault) Close() error { return err } -func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { +func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { return m.params.UDPConn.WriteTo(buf, rAddr) } -func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) { +func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) { if m.IsClosed() { return } @@ -368,81 +376,109 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) } -func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { +func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn { c := newUDPMuxedConn(&udpMuxedConnParams{ - Mux: m, - Key: key, - AddrPool: m.pool, - LocalAddr: m.LocalAddr(), - Logger: m.params.Logger, + Mux: m, + Key: key, + AddrPool: m.pool, + LocalAddr: 
m.LocalAddr(), + Logger: m.params.Logger, + CandidateID: candidateID, }) return c } // HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library -func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error { - +func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error { remoteAddr, ok := addr.(*net.UDPAddr) if !ok { return fmt.Errorf("underlying PacketConn did not return a UDPAddr") } - // If we have already seen this address dispatch to the appropriate destination - // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one - // muxed connection - one for the SRFLX candidate and the other one for the HOST one. - // We will then forward STUN packets to each of these connections. - m.addressMapMu.RLock() + // Try to route to specific candidate connection first + if conn := m.findCandidateConnection(msg); conn != nil { + return conn.writePacket(msg.Raw, remoteAddr) + } + + // Fallback: route to all possible connections + return m.forwardToAllConnections(msg, addr, remoteAddr) +} + +// findCandidateConnection attempts to find the specific connection for a STUN message +func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn { + candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg) + if err != nil { + return nil + } else if !ok { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()] + if !exists { + return nil + } + return conn +} + +// forwardToAllConnections forwards STUN message to all relevant connections +func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error { var destinationConnList []*udpMuxedConn + + // Add connections from address map + m.addressMapMu.RLock() if storedConns, ok := m.addressMap[addr.String()]; ok { destinationConnList = 
append(destinationConnList, storedConns...) } m.addressMapMu.RUnlock() - var isIPv6 bool - if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { - isIPv6 = true + if conn, ok := m.findConnectionByUsername(msg, addr); ok { + // If we have already seen this address dispatch to the appropriate destination + // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one + // muxed connection - one for the SRFLX candidate and the other one for the HOST one. + // We will then forward STUN packets to each of these connections. + if !m.connectionExists(conn, destinationConnList) { + destinationConnList = append(destinationConnList, conn) + } } - // This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront. - // However, we can take a username attribute from the STUN message which contains ufrag. - // We can use ufrag to identify the destination conn to route packet to. - attr, stunAttrErr := msg.Get(stun.AttrUsername) - if stunAttrErr == nil { - ufrag := strings.Split(string(attr), ":")[0] - - m.mu.Lock() - destinationConn := m.connsIPv4[ufrag] - if isIPv6 { - destinationConn = m.connsIPv6[ufrag] - } - - if destinationConn != nil { - exists := false - for _, conn := range destinationConnList { - if conn.params.Key == destinationConn.params.Key { - exists = true - break - } - } - if !exists { - destinationConnList = append(destinationConnList, destinationConn) - } - } - m.mu.Unlock() - } - - // Forward STUN packets to each destination connections even thought the STUN packet might not belong there. - // It will be discarded by the further ICE candidate logic if so. 
+ // Forward to all found connections for _, conn := range destinationConnList { if err := conn.writePacket(msg.Raw, remoteAddr); err != nil { log.Errorf("could not write packet: %v", err) } } - return nil } -func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { +// findConnectionByUsername finds connection using username attribute from STUN message +func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) { + attr, err := msg.Get(stun.AttrUsername) + if err != nil { + return nil, false + } + + ufrag := strings.Split(string(attr), ":")[0] + isIPv6 := isIPv6Address(addr) + + m.mu.Lock() + defer m.mu.Unlock() + + return m.getConn(ufrag, isIPv6) +} + +// connectionExists checks if a connection already exists in the list +func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool { + for _, conn := range conns { + if conn.params.Key == target.params.Key { + return true + } + } + return false +} + +func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { if isIPv6 { val, ok = m.connsIPv6[ufrag] } else { @@ -451,6 +487,13 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o return } +func isIPv6Address(addr net.Addr) bool { + if udpAddr, ok := addr.(*net.UDPAddr); ok { + return udpAddr.IP.To4() == nil + } + return false +} + type bufferHolder struct { buf []byte } diff --git a/client/iface/bind/udp_mux_generic.go b/client/iface/udpmux/mux_generic.go similarity index 85% rename from client/iface/bind/udp_mux_generic.go rename to client/iface/udpmux/mux_generic.go index 63f786d2b..cf3043be0 100644 --- a/client/iface/bind/udp_mux_generic.go +++ b/client/iface/udpmux/mux_generic.go @@ -1,12 +1,12 @@ //go:build !ios -package bind +package udpmux import ( nbnet "github.com/netbirdio/netbird/util/net" ) -func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { +func (m 
*SingleSocketUDPMux) notifyAddressRemoval(addr string) { // Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet) if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok { conn.RemoveAddress(addr) diff --git a/client/iface/udpmux/mux_ios.go b/client/iface/udpmux/mux_ios.go new file mode 100644 index 000000000..4cf211d8f --- /dev/null +++ b/client/iface/udpmux/mux_ios.go @@ -0,0 +1,7 @@ +//go:build ios + +package udpmux + +func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) { + // iOS doesn't support nbnet hooks, so this is a no-op +} diff --git a/client/iface/bind/udp_mux_universal.go b/client/iface/udpmux/universal.go similarity index 97% rename from client/iface/bind/udp_mux_universal.go rename to client/iface/udpmux/universal.go index a1f517dcd..43bfedaaa 100644 --- a/client/iface/bind/udp_mux_universal.go +++ b/client/iface/udpmux/universal.go @@ -1,4 +1,4 @@ -package bind +package udpmux /* Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements. @@ -29,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error) // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn // It then passes packets to the UDPMux that does the actual connection muxing. 
type UniversalUDPMuxDefault struct { - *UDPMuxDefault + *SingleSocketUDPMux params UniversalUDPMuxParams // since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents @@ -72,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef address: params.WGAddress, } - udpMuxParams := UDPMuxParams{ + udpMuxParams := Params{ Logger: params.Logger, UDPConn: m.params.UDPConn, Net: m.params.Net, } - m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) + m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams) return m } @@ -211,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time // GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers // and return a unique connection per server. -func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) { - return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr) +func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) { + return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID) } // HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server. @@ -233,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A } return nil } - return m.UDPMuxDefault.HandleSTUNMessage(msg, addr) + return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr) } // isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. 
diff --git a/client/internal/engine.go b/client/internal/engine.go index 61ff41600..9dc744434 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -29,9 +29,9 @@ import ( "github.com/netbirdio/netbird/client/firewall" firewallManager "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" @@ -166,7 +166,7 @@ type Engine struct { wgInterface WGIface - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service networkSerial uint64 @@ -461,7 +461,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) StunTurn: &e.stunTurn, InterfaceBlackList: e.config.IFaceBlackList, DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.UDPMuxDefault, + UDPMux: e.udpMux.SingleSocketUDPMux, UDPMuxSrflx: e.udpMux, NATExternalIPs: e.parseNATExternalIPMappings(), } @@ -1326,7 +1326,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV StunTurn: &e.stunTurn, InterfaceBlackList: e.config.IFaceBlackList, DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.UDPMuxDefault, + UDPMux: e.udpMux.SingleSocketUDPMux, UDPMuxSrflx: e.udpMux, NATExternalIPs: e.parseNATExternalIPMappings(), }, diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 90c8cbc60..4d2e81f43 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -26,10 +26,11 @@ import ( "google.golang.org/grpc/keepalive" 
"github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/internal/dns" @@ -84,7 +85,7 @@ type MockWGIface struct { NameFunc func() string AddressFunc func() wgaddr.Address ToInterfaceFunc func() *net.Interface - UpFunc func() (*bind.UniversalUDPMuxDefault, error) + UpFunc func() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddrFunc func(newAddr string) error UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeerFunc func(peerKey string) error @@ -134,7 +135,7 @@ func (m *MockWGIface) ToInterface() *net.Interface { return m.ToInterfaceFunc() } -func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) { +func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { return m.UpFunc() } @@ -413,7 +414,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { if err != nil { t.Fatal(err) } - engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280}) + engine.udpMux = udpmux.NewUniversalUDPMuxDefault(udpmux.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280}) engine.ctx = ctx engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface) diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index bf96153ea..690fdb7cc 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -9,9 +9,9 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" 
- "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/monotime" @@ -24,7 +24,7 @@ type wgIfaceBase interface { Name() string Address() wgaddr.Address ToInterface() *net.Interface - Up() (*bind.UniversalUDPMuxDefault, error) + Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error GetProxy() wgproxy.Proxy UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 896c55b6c..4e85ba0fa 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -9,11 +9,10 @@ import ( "time" "github.com/pion/ice/v4" - "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/peer/conntype" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" @@ -55,10 +54,6 @@ type WorkerICE struct { sessionID ICESessionID muxAgent sync.Mutex - StunTurn []*stun.URI - - sentExtraSrflx bool - localUfrag string localPwd string @@ -139,7 +134,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.muxAgent.Unlock() return } - w.sentExtraSrflx = false w.agent = agent w.agentDialerCancel = dialerCancel w.agentConnecting = true @@ -166,6 +160,21 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA w.log.Errorf("error while handling remote candidate") return } + + if shouldAddExtraCandidate(candidate) { + 
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port) + // this is useful when network has an existing port forwarding rule for the wireguard port and this peer + extraSrflx, err := extraSrflxCandidate(candidate) + if err != nil { + w.log.Errorf("failed creating extra server reflexive candidate %s", err) + return + } + + if err := w.agent.AddRemoteCandidate(extraSrflx); err != nil { + w.log.Errorf("error while handling remote candidate") + return + } + } } func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) { @@ -327,7 +336,7 @@ func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) return } - mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault) + mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*udpmux.UniversalUDPMuxDefault) if !ok { w.log.Warn("invalid udp mux conversion") return @@ -354,26 +363,6 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) { w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err) } }() - - if !w.shouldSendExtraSrflxCandidate(candidate) { - return - } - - // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port) - // this is useful when network has an existing port forwarding rule for the wireguard port and this peer - extraSrflx, err := extraSrflxCandidate(candidate) - if err != nil { - w.log.Errorf("failed creating extra server reflexive candidate %s", err) - return - } - w.sentExtraSrflx = true - - go func() { - err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key) - if err != nil { - w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err) - } - }() } func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) { @@ -424,22 +413,31 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia } } -func (w *WorkerICE) 
shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool { - if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port { - return true - } - return false -} - func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { - isControlling := w.config.LocalKey > w.config.Key - if isControlling { - return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) + if isController(w.config) { + return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } else { return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } } +func shouldAddExtraCandidate(candidate ice.Candidate) bool { + if candidate.Type() != ice.CandidateTypeServerReflexive { + return false + } + + if candidate.Port() == candidate.RelatedAddress().Port { + return false + } + + // in the older version when we didn't set candidate ID extension the remote peer sent the extra candidates + // in newer version we generate locally the extra candidate + if _, ok := candidate.GetExtension(ice.ExtensionKeyCandidateID); !ok { + return false + } + return true +} + func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { relatedAdd := candidate.RelatedAddress() ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ @@ -455,6 +453,10 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive } for _, e := range candidate.Extensions() { + // overwrite the original candidate ID with the new one to avoid candidate duplication + if e.Key == ice.ExtensionKeyCandidateID { + e.Value = candidate.ID() + } if err := ec.AddExtension(e); err != nil { return nil, err } diff --git a/go.mod b/go.mod index 70e52875f..23aa45277 100644 --- a/go.mod +++ b/go.mod @@ -261,6 +261,6 
@@ replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-2 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 -replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 +replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 diff --git a/go.sum b/go.sum index 3fdef5d08..7096be3fe 100644 --- a/go.sum +++ b/go.sum @@ -501,8 +501,8 @@ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJE github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= -github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 h1:ZJwhKexMlK15B/Ld+1T8VYE2Mt1lk1kf2DlXr46EHcw= -github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= +github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= +github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ= github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= From 47e64d72dbe69ac3245e7c7140e2db868c8c8782 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 
11 Sep 2025 16:21:09 +0200 Subject: [PATCH 26/41] [client] Fix client status check (#4474) The client status is not enough to protect the RPC calls from concurrency issues, because it is handled internally in the client in an asynchronous way. --- client/internal/connect.go | 7 +- client/server/server.go | 408 ++++++++++++++++++----------------- client/server/server_test.go | 1 + 3 files changed, 213 insertions(+), 203 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index f20b8d361..33cd4b4a1 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -280,15 +280,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan return wrapErr(err) } - log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) state.Set(StatusConnected) if runningChan != nil { - select { - case runningChan <- struct{}{}: - default: - } + close(runningChan) + runningChan = nil } <-engineCtx.Done() diff --git a/client/server/server.go b/client/server/server.go index d89c7ce91..fae342f78 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -65,6 +65,8 @@ type Server struct { mutex sync.Mutex config *profilemanager.Config proto.UnimplementedDaemonServiceServer + clientRunning bool // protected by mutex + clientRunningChan chan struct{} connectClient *internal.ConnectClient @@ -103,6 +105,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable func (s *Server) Start() error { s.mutex.Lock() defer s.mutex.Unlock() + state := internal.CtxGetState(s.rootCtx) if err := handlePanicLog(); err != nil { @@ -172,8 +175,12 @@ func (s *Server) Start() error { return nil } - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) - + if s.clientRunning { + return nil + } + s.clientRunning = true + s.clientRunningChan = make(chan struct{}, 1) + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan) return nil } @@ -204,12 +211,22 @@ func 
(s *Server) setDefaultConfigIfNotExists(ctx context.Context) error { // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. -func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, - runningChan chan struct{}, -) { - backOff := getConnectWithBackoff(ctx) - retryStarted := false +func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) { + defer func() { + s.mutex.Lock() + s.clientRunning = false + s.mutex.Unlock() + }() + if s.config.DisableAutoConnect { + if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil { + log.Debugf("run client connection exited with error: %v", err) + } + log.Tracef("client connection exited") + return + } + + backOff := getConnectWithBackoff(ctx) go func() { t := time.NewTicker(24 * time.Hour) for { @@ -218,91 +235,34 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanage t.Stop() return case <-t.C: - if retryStarted { - - mgmtState := statusRecorder.GetManagementState() - signalState := statusRecorder.GetSignalState() - if mgmtState.Connected && signalState.Connected { - log.Tracef("resetting status") - retryStarted = false - } else { - log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected) - } + mgmtState := statusRecorder.GetManagementState() + signalState := statusRecorder.GetSignalState() + if mgmtState.Connected && signalState.Connected { + log.Tracef("resetting status") + backOff.Reset() + } else { + log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected) } } } }() runOperation := func() error 
{ - log.Tracef("running client connection") - s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) - s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) - - err := s.connectClient.Run(runningChan) + err := s.connect(ctx, profileConfig, statusRecorder, runningChan) if err != nil { log.Debugf("run client connection exited with error: %v. Will retry in the background", err) + return err } - if config.DisableAutoConnect { - return backoff.Permanent(err) - } - - if !retryStarted { - retryStarted = true - backOff.Reset() - } - - log.Tracef("client connection exited") - return fmt.Errorf("client connection exited") + log.Tracef("client connection exited gracefully, do not need to retry") + return nil } - err := backoff.Retry(runOperation, backOff) - if s, ok := gstatus.FromError(err); ok && s.Code() != codes.Canceled { - log.Errorf("received an error when trying to connect: %v", err) - } else { - log.Tracef("retry canceled") + if err := backoff.Retry(runOperation, backOff); err != nil { + log.Errorf("operation failed: %v", err) } } -// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries -func getConnectWithBackoff(ctx context.Context) backoff.BackOff { - initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime) - maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval) - maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime) - multiplier := defaultRetryMultiplier - - if envValue := os.Getenv(retryMultiplierVar); envValue != "" { - // parse the multiplier from the environment variable string value to float64 - value, err := strconv.ParseFloat(envValue, 64) - if err != nil { - log.Warnf("unable to parse environment variable %s: %s. 
using default: %f", retryMultiplierVar, envValue, multiplier) - } else { - multiplier = value - } - } - - return backoff.WithContext(&backoff.ExponentialBackOff{ - InitialInterval: initialInterval, - RandomizationFactor: 1, - Multiplier: multiplier, - MaxInterval: maxInterval, - MaxElapsedTime: maxElapsedTime, // 14 days - Stop: backoff.Stop, - Clock: backoff.SystemClock, - }, ctx) -} - -// parseEnvDuration parses the environment variable and returns the duration -func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration { - if envValue := os.Getenv(envVar); envValue != "" { - if duration, err := time.ParseDuration(envValue); err == nil { - return duration - } - log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration) - } - return defaultDuration -} - // loginAttempt attempts to login using the provided information. it returns a status in case something fails func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) { var status internal.StatusType @@ -716,11 +676,14 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second) defer cancel() - runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan) + if !s.clientRunning { + s.clientRunning = true + s.clientRunningChan = make(chan struct{}, 1) + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan) + } for { select { - case <-runningChan: + case <-s.clientRunningChan: s.isSessionActive.Store(true) return &proto.UpResponse{}, nil case <-callerCtx.Done(): @@ -1127,6 +1090,134 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p }, nil } +// AddProfile adds a new profile to the daemon. 
+func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.checkProfilesDisabled() { + return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled) + } + + if msg.ProfileName == "" || msg.Username == "" { + return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided") + } + + if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil { + log.Errorf("failed to create profile: %v", err) + return nil, fmt.Errorf("failed to create profile: %w", err) + } + + return &proto.AddProfileResponse{}, nil +} + +// RemoveProfile removes a profile from the daemon. +func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if err := s.validateProfileOperation(msg.ProfileName, false); err != nil { + return nil, err + } + + if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil { + log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err) + } + + if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil { + log.Errorf("failed to remove profile: %v", err) + return nil, fmt.Errorf("failed to remove profile: %w", err) + } + + return &proto.RemoveProfileResponse{}, nil +} + +// ListProfiles lists all profiles in the daemon. 
+func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if msg.Username == "" { + return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided") + } + + profiles, err := s.profileManager.ListProfiles(msg.Username) + if err != nil { + log.Errorf("failed to list profiles: %v", err) + return nil, fmt.Errorf("failed to list profiles: %w", err) + } + + response := &proto.ListProfilesResponse{ + Profiles: make([]*proto.Profile, len(profiles)), + } + for i, profile := range profiles { + response.Profiles[i] = &proto.Profile{ + Name: profile.Name, + IsActive: profile.IsActive, + } + } + + return response, nil +} + +// GetActiveProfile returns the active profile in the daemon. +func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + activeProfile, err := s.profileManager.GetActiveProfileState() + if err != nil { + log.Errorf("failed to get active profile state: %v", err) + return nil, fmt.Errorf("failed to get active profile state: %w", err) + } + + return &proto.GetActiveProfileResponse{ + ProfileName: activeProfile.Name, + Username: activeProfile.Username, + }, nil +} + +// GetFeatures returns the features supported by the daemon. 
+func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + features := &proto.GetFeaturesResponse{ + DisableProfiles: s.checkProfilesDisabled(), + DisableUpdateSettings: s.checkUpdateSettingsDisabled(), + } + + return features, nil +} + +func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error { + log.Tracef("running client connection") + s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) + s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) + if err := s.connectClient.Run(runningChan); err != nil { + return err + } + return nil +} + +func (s *Server) checkProfilesDisabled() bool { + // Check if the environment variable is set to disable profiles + if s.profilesDisabled { + return true + } + + return false +} + +func (s *Server) checkUpdateSettingsDisabled() bool { + // Check if the environment variable is set to disable update settings + if s.updateSettingsDisabled { + return true + } + + return false +} + func (s *Server) onSessionExpire() { if runtime.GOOS != "windows" { isUIActive := internal.CheckUIApp() @@ -1138,6 +1229,45 @@ func (s *Server) onSessionExpire() { } } +// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries +func getConnectWithBackoff(ctx context.Context) backoff.BackOff { + initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime) + maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval) + maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime) + multiplier := defaultRetryMultiplier + + if envValue := os.Getenv(retryMultiplierVar); envValue != "" { + // parse the multiplier from the environment variable string value to float64 + value, err := strconv.ParseFloat(envValue, 64) + if err != nil { + log.Warnf("unable to parse 
environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier) + } else { + multiplier = value + } + } + + return backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: initialInterval, + RandomizationFactor: 1, + Multiplier: multiplier, + MaxInterval: maxInterval, + MaxElapsedTime: maxElapsedTime, // 14 days + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, ctx) +} + +// parseEnvDuration parses the environment variable and returns the duration +func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration { + if envValue := os.Getenv(envVar); envValue != "" { + if duration, err := time.ParseDuration(envValue); err == nil { + return duration + } + log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration) + } + return defaultDuration +} + func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { pbFullStatus := proto.FullStatus{ ManagementState: &proto.ManagementState{}, @@ -1252,121 +1382,3 @@ func sendTerminalNotification() error { return wallCmd.Wait() } - -// AddProfile adds a new profile to the daemon. -func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.checkProfilesDisabled() { - return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled) - } - - if msg.ProfileName == "" || msg.Username == "" { - return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided") - } - - if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil { - log.Errorf("failed to create profile: %v", err) - return nil, fmt.Errorf("failed to create profile: %w", err) - } - - return &proto.AddProfileResponse{}, nil -} - -// RemoveProfile removes a profile from the daemon. 
-func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if err := s.validateProfileOperation(msg.ProfileName, false); err != nil { - return nil, err - } - - if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil { - log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err) - } - - if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil { - log.Errorf("failed to remove profile: %v", err) - return nil, fmt.Errorf("failed to remove profile: %w", err) - } - - return &proto.RemoveProfileResponse{}, nil -} - -// ListProfiles lists all profiles in the daemon. -func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if msg.Username == "" { - return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided") - } - - profiles, err := s.profileManager.ListProfiles(msg.Username) - if err != nil { - log.Errorf("failed to list profiles: %v", err) - return nil, fmt.Errorf("failed to list profiles: %w", err) - } - - response := &proto.ListProfilesResponse{ - Profiles: make([]*proto.Profile, len(profiles)), - } - for i, profile := range profiles { - response.Profiles[i] = &proto.Profile{ - Name: profile.Name, - IsActive: profile.IsActive, - } - } - - return response, nil -} - -// GetActiveProfile returns the active profile in the daemon. 
-func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - activeProfile, err := s.profileManager.GetActiveProfileState() - if err != nil { - log.Errorf("failed to get active profile state: %v", err) - return nil, fmt.Errorf("failed to get active profile state: %w", err) - } - - return &proto.GetActiveProfileResponse{ - ProfileName: activeProfile.Name, - Username: activeProfile.Username, - }, nil -} - -// GetFeatures returns the features supported by the daemon. -func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - features := &proto.GetFeaturesResponse{ - DisableProfiles: s.checkProfilesDisabled(), - DisableUpdateSettings: s.checkUpdateSettingsDisabled(), - } - - return features, nil -} - -func (s *Server) checkProfilesDisabled() bool { - // Check if the environment variable is set to disable profiles - if s.profilesDisabled { - return true - } - - return false -} - -func (s *Server) checkUpdateSettingsDisabled() bool { - // Check if the environment variable is set to disable profiles - if s.updateSettingsDisabled { - return true - } - - return false -} diff --git a/client/server/server_test.go b/client/server/server_test.go index 87889cbce..45a1aa5c7 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" From cf7f6c355f713e83cf171b79e08dac60b316e4fd Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Thu, 11 Sep 2025 22:20:10 +0300 Subject: [PATCH 27/41] [misc] Remove default zitadel admin user in deployment script (#4482) * Delete 
default zitadel-admin user during initialization Signed-off-by: bcmmbaga * Refactor Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- .../getting-started-with-zitadel.sh | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 2d7c65cbe..cfec1000e 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -328,6 +328,45 @@ delete_auto_service_user() { echo "$PARSED_RESPONSE" } +delete_default_zitadel_admin() { + INSTANCE_URL=$1 + PAT=$2 + + # Search for the default zitadel-admin user + RESPONSE=$( + curl -sS -X POST "$INSTANCE_URL/management/v1/users/_search" \ + -H "Authorization: Bearer $PAT" \ + -H "Content-Type: application/json" \ + -d '{ + "queries": [ + { + "userNameQuery": { + "userName": "zitadel-admin@", + "method": "TEXT_QUERY_METHOD_STARTS_WITH" + } + } + ] + }' + ) + + DEFAULT_ADMIN_ID=$(echo "$RESPONSE" | jq -r '.result[0].id // empty') + + if [ -n "$DEFAULT_ADMIN_ID" ] && [ "$DEFAULT_ADMIN_ID" != "null" ]; then + echo "Found default zitadel-admin user with ID: $DEFAULT_ADMIN_ID" + + RESPONSE=$( + curl -sS -X DELETE "$INSTANCE_URL/management/v1/users/$DEFAULT_ADMIN_ID" \ + -H "Authorization: Bearer $PAT" \ + -H "Content-Type: application/json" \ + ) + PARSED_RESPONSE=$(echo "$RESPONSE" | jq -r '.details.changeDate // "deleted"') + handle_zitadel_request_response "$PARSED_RESPONSE" "delete_default_zitadel_admin" "$RESPONSE" + + else + echo "Default zitadel-admin user not found: $RESPONSE" + fi +} + init_zitadel() { echo -e "\nInitializing Zitadel with NetBird's applications\n" INSTANCE_URL="$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN" @@ -346,6 +385,9 @@ init_zitadel() { echo -n "Waiting for Zitadel to become ready " wait_api "$INSTANCE_URL" "$PAT" + echo "Deleting default zitadel-admin user..." 
+ delete_default_zitadel_admin "$INSTANCE_URL" "$PAT" + # create the zitadel project echo "Creating new zitadel project" PROJECT_ID=$(create_new_project "$INSTANCE_URL" "$PAT") From 0c6f671a7c8ab9befb9580c9297a1b0de05286cb Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 12 Sep 2025 09:31:03 +0200 Subject: [PATCH 28/41] Refactor healthcheck sender and receiver to use configurable options (#4433) --- .github/workflows/golang-test-linux.yml | 2 +- shared/relay/client/manager.go | 7 +- shared/relay/client/picker.go | 18 +++--- shared/relay/client/picker_test.go | 10 +-- shared/relay/healthcheck/env.go | 24 +++++++ shared/relay/healthcheck/env_test.go | 36 +++++++++++ shared/relay/healthcheck/receiver.go | 32 +++++++-- shared/relay/healthcheck/receiver_test.go | 79 ++++++----------------- shared/relay/healthcheck/sender.go | 76 ++++++++++++---------- shared/relay/healthcheck/sender_test.go | 78 ++++++---------------- 10 files changed, 186 insertions(+), 176 deletions(-) create mode 100644 shared/relay/healthcheck/env.go create mode 100644 shared/relay/healthcheck/env_test.go diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index f7b4e238f..ba36c013b 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -217,7 +217,7 @@ jobs: - arch: "386" raceFlag: "" - arch: "amd64" - raceFlag: "" + raceFlag: "-race" runs-on: ubuntu-22.04 steps: - name: Install Go diff --git a/shared/relay/client/manager.go b/shared/relay/client/manager.go index a40343fb1..6220e7f6b 100644 --- a/shared/relay/client/manager.go +++ b/shared/relay/client/manager.go @@ -78,9 +78,10 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin tokenStore: tokenStore, mtu: mtu, serverPicker: &ServerPicker{ - TokenStore: tokenStore, - PeerID: peerID, - MTU: mtu, + TokenStore: tokenStore, + PeerID: peerID, + MTU: mtu, + ConnectionTimeout: defaultConnectionTimeout, }, relayClients: 
make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]*list.List), diff --git a/shared/relay/client/picker.go b/shared/relay/client/picker.go index b6c7b5e8a..39d0ba072 100644 --- a/shared/relay/client/picker.go +++ b/shared/relay/client/picker.go @@ -13,11 +13,8 @@ import ( ) const ( - maxConcurrentServers = 7 -) - -var ( - connectionTimeout = 30 * time.Second + maxConcurrentServers = 7 + defaultConnectionTimeout = 30 * time.Second ) type connResult struct { @@ -27,14 +24,15 @@ type connResult struct { } type ServerPicker struct { - TokenStore *auth.TokenStore - ServerURLs atomic.Value - PeerID string - MTU uint16 + TokenStore *auth.TokenStore + ServerURLs atomic.Value + PeerID string + MTU uint16 + ConnectionTimeout time.Duration } func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) { - ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout) + ctx, cancel := context.WithTimeout(parentCtx, sp.ConnectionTimeout) defer cancel() totalServers := len(sp.ServerURLs.Load().([]string)) diff --git a/shared/relay/client/picker_test.go b/shared/relay/client/picker_test.go index 28167c5ce..fb3fa7375 100644 --- a/shared/relay/client/picker_test.go +++ b/shared/relay/client/picker_test.go @@ -8,15 +8,15 @@ import ( ) func TestServerPicker_UnavailableServers(t *testing.T) { - connectionTimeout = 5 * time.Second - + timeout := 5 * time.Second sp := ServerPicker{ - TokenStore: nil, - PeerID: "test", + TokenStore: nil, + PeerID: "test", + ConnectionTimeout: timeout, } sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"}) - ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1) + ctx, cancel := context.WithTimeout(context.Background(), timeout+1) defer cancel() go func() { diff --git a/shared/relay/healthcheck/env.go b/shared/relay/healthcheck/env.go new file mode 100644 index 000000000..2b584c195 --- /dev/null +++ b/shared/relay/healthcheck/env.go @@ -0,0 +1,24 @@ +package healthcheck + +import ( + 
"os" + "strconv" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD" +) + +func getAttemptThresholdFromEnv() int { + if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" { + threshold, err := strconv.ParseInt(attemptThreshold, 10, 64) + if err != nil { + log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold) + return defaultAttemptThreshold + } + return int(threshold) + } + return defaultAttemptThreshold +} diff --git a/shared/relay/healthcheck/env_test.go b/shared/relay/healthcheck/env_test.go new file mode 100644 index 000000000..2e14bb8bf --- /dev/null +++ b/shared/relay/healthcheck/env_test.go @@ -0,0 +1,36 @@ +package healthcheck + +import ( + "os" + "testing" +) + +//nolint:tenv +func TestGetAttemptThresholdFromEnv(t *testing.T) { + tests := []struct { + name string + envValue string + expected int + }{ + {"Default attempt threshold when env is not set", "", defaultAttemptThreshold}, + {"Custom attempt threshold when env is set to a valid integer", "3", 3}, + {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.envValue == "" { + os.Unsetenv(defaultAttemptThresholdEnv) + } else { + os.Setenv(defaultAttemptThresholdEnv, tt.envValue) + } + + result := getAttemptThresholdFromEnv() + if result != tt.expected { + t.Fatalf("Expected %d, got %d", tt.expected, result) + } + + os.Unsetenv(defaultAttemptThresholdEnv) + }) + } +} diff --git a/shared/relay/healthcheck/receiver.go b/shared/relay/healthcheck/receiver.go index b3503d5db..90f795bbe 100644 --- a/shared/relay/healthcheck/receiver.go +++ b/shared/relay/healthcheck/receiver.go @@ -7,10 +7,15 @@ import ( log "github.com/sirupsen/logrus" ) -var ( - heartbeatTimeout = healthCheckInterval + 10*time.Second 
+const ( + defaultHeartbeatTimeout = defaultHealthCheckInterval + 10*time.Second ) +type ReceiverOptions struct { + HeartbeatTimeout time.Duration + AttemptThreshold int +} + // Receiver is a healthcheck receiver // It will listen for heartbeat and check if the heartbeat is not received in a certain time // If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work @@ -27,6 +32,23 @@ type Receiver struct { // NewReceiver creates a new healthcheck receiver and start the timer in the background func NewReceiver(log *log.Entry) *Receiver { + opts := ReceiverOptions{ + HeartbeatTimeout: defaultHeartbeatTimeout, + AttemptThreshold: getAttemptThresholdFromEnv(), + } + return NewReceiverWithOpts(log, opts) +} + +func NewReceiverWithOpts(log *log.Entry, opts ReceiverOptions) *Receiver { + heartbeatTimeout := opts.HeartbeatTimeout + if heartbeatTimeout <= 0 { + heartbeatTimeout = defaultHeartbeatTimeout + } + attemptThreshold := opts.AttemptThreshold + if attemptThreshold <= 0 { + attemptThreshold = defaultAttemptThreshold + } + ctx, ctxCancel := context.WithCancel(context.Background()) r := &Receiver{ @@ -35,10 +57,10 @@ func NewReceiver(log *log.Entry) *Receiver { ctx: ctx, ctxCancel: ctxCancel, heartbeat: make(chan struct{}, 1), - attemptThreshold: getAttemptThresholdFromEnv(), + attemptThreshold: attemptThreshold, } - go r.waitForHealthcheck() + go r.waitForHealthcheck(heartbeatTimeout) return r } @@ -55,7 +77,7 @@ func (r *Receiver) Stop() { r.ctxCancel() } -func (r *Receiver) waitForHealthcheck() { +func (r *Receiver) waitForHealthcheck(heartbeatTimeout time.Duration) { ticker := time.NewTicker(heartbeatTimeout) defer ticker.Stop() defer r.ctxCancel() diff --git a/shared/relay/healthcheck/receiver_test.go b/shared/relay/healthcheck/receiver_test.go index 2794159f6..b20cc5124 100644 --- a/shared/relay/healthcheck/receiver_test.go +++ b/shared/relay/healthcheck/receiver_test.go @@ -2,31 +2,18 @@ package healthcheck import ( 
"context" - "fmt" - "os" - "sync" "testing" "time" log "github.com/sirupsen/logrus" ) -// Mutex to protect global variable access in tests -var testMutex sync.Mutex - func TestNewReceiver(t *testing.T) { - testMutex.Lock() - originalTimeout := heartbeatTimeout - heartbeatTimeout = 5 * time.Second - testMutex.Unlock() - defer func() { - testMutex.Lock() - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - - r := NewReceiver(log.WithContext(context.Background())) + opts := ReceiverOptions{ + HeartbeatTimeout: 5 * time.Second, + } + r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) defer r.Stop() select { @@ -38,18 +25,10 @@ func TestNewReceiver(t *testing.T) { } func TestNewReceiverNotReceive(t *testing.T) { - testMutex.Lock() - originalTimeout := heartbeatTimeout - heartbeatTimeout = 1 * time.Second - testMutex.Unlock() - - defer func() { - testMutex.Lock() - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - - r := NewReceiver(log.WithContext(context.Background())) + opts := ReceiverOptions{ + HeartbeatTimeout: 1 * time.Second, + } + r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) defer r.Stop() select { @@ -61,18 +40,10 @@ func TestNewReceiverNotReceive(t *testing.T) { } func TestNewReceiverAck(t *testing.T) { - testMutex.Lock() - originalTimeout := heartbeatTimeout - heartbeatTimeout = 2 * time.Second - testMutex.Unlock() - - defer func() { - testMutex.Lock() - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - - r := NewReceiver(log.WithContext(context.Background())) + opts := ReceiverOptions{ + HeartbeatTimeout: 2 * time.Second, + } + r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) defer r.Stop() r.Heartbeat() @@ -97,30 +68,19 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) { for _, tc := range testsCases { t.Run(tc.name, func(t *testing.T) { - testMutex.Lock() - originalInterval := healthCheckInterval - originalTimeout := heartbeatTimeout - 
healthCheckInterval = 1 * time.Second - heartbeatTimeout = healthCheckInterval + 500*time.Millisecond - testMutex.Unlock() + healthCheckInterval := 1 * time.Second - defer func() { - testMutex.Lock() - healthCheckInterval = originalInterval - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - //nolint:tenv - os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) - defer os.Unsetenv(defaultAttemptThresholdEnv) + opts := ReceiverOptions{ + HeartbeatTimeout: healthCheckInterval + 500*time.Millisecond, + AttemptThreshold: tc.threshold, + } - receiver := NewReceiver(log.WithField("test_name", tc.name)) + receiver := NewReceiverWithOpts(log.WithField("test_name", tc.name), opts) - testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval + testTimeout := opts.HeartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval if tc.resetCounterOnce { receiver.Heartbeat() - t.Logf("reset counter once") } select { @@ -134,7 +94,6 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) { } t.Fatalf("should have timed out before %s", testTimeout) } - }) } } diff --git a/shared/relay/healthcheck/sender.go b/shared/relay/healthcheck/sender.go index 57b3015ec..771e94206 100644 --- a/shared/relay/healthcheck/sender.go +++ b/shared/relay/healthcheck/sender.go @@ -2,52 +2,76 @@ package healthcheck import ( "context" - "os" - "strconv" "time" log "github.com/sirupsen/logrus" ) const ( - defaultAttemptThreshold = 1 - defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD" + defaultAttemptThreshold = 1 + + defaultHealthCheckInterval = 25 * time.Second + defaultHealthCheckTimeout = 20 * time.Second ) -var ( - healthCheckInterval = 25 * time.Second - healthCheckTimeout = 20 * time.Second -) +type SenderOptions struct { + HealthCheckInterval time.Duration + HealthCheckTimeout time.Duration + AttemptThreshold int +} // Sender is a healthcheck sender // It will send healthcheck signal to the receiver // If the receiver does 
not receive the signal in a certain time, it will send a timeout signal and stop to work // It will also stop if the context is canceled type Sender struct { - log *log.Entry // HealthCheck is a channel to send health check signal to the peer HealthCheck chan struct{} // Timeout is a channel to the health check signal is not received in a certain time Timeout chan struct{} + log *log.Entry + healthCheckInterval time.Duration + timeout time.Duration + ack chan struct{} alive bool attemptThreshold int } -// NewSender creates a new healthcheck sender -func NewSender(log *log.Entry) *Sender { +func NewSenderWithOpts(log *log.Entry, opts SenderOptions) *Sender { + if opts.HealthCheckInterval <= 0 { + opts.HealthCheckInterval = defaultHealthCheckInterval + } + if opts.HealthCheckTimeout <= 0 { + opts.HealthCheckTimeout = defaultHealthCheckTimeout + } + if opts.AttemptThreshold <= 0 { + opts.AttemptThreshold = defaultAttemptThreshold + } hc := &Sender{ - log: log, - HealthCheck: make(chan struct{}, 1), - Timeout: make(chan struct{}, 1), - ack: make(chan struct{}, 1), - attemptThreshold: getAttemptThresholdFromEnv(), + HealthCheck: make(chan struct{}, 1), + Timeout: make(chan struct{}, 1), + log: log, + healthCheckInterval: opts.HealthCheckInterval, + timeout: opts.HealthCheckInterval + opts.HealthCheckTimeout, + ack: make(chan struct{}, 1), + attemptThreshold: opts.AttemptThreshold, } return hc } +// NewSender creates a new healthcheck sender +func NewSender(log *log.Entry) *Sender { + opts := SenderOptions{ + HealthCheckInterval: defaultHealthCheckInterval, + HealthCheckTimeout: defaultHealthCheckTimeout, + AttemptThreshold: getAttemptThresholdFromEnv(), + } + return NewSenderWithOpts(log, opts) +} + // OnHCResponse sends an acknowledgment signal to the sender func (hc *Sender) OnHCResponse() { select { @@ -57,10 +81,10 @@ func (hc *Sender) OnHCResponse() { } func (hc *Sender) StartHealthCheck(ctx context.Context) { - ticker := time.NewTicker(healthCheckInterval) + 
ticker := time.NewTicker(hc.healthCheckInterval) defer ticker.Stop() - timeoutTicker := time.NewTicker(hc.getTimeoutTime()) + timeoutTicker := time.NewTicker(hc.timeout) defer timeoutTicker.Stop() defer close(hc.HealthCheck) @@ -92,19 +116,3 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) { } } } - -func (hc *Sender) getTimeoutTime() time.Duration { - return healthCheckInterval + healthCheckTimeout -} - -func getAttemptThresholdFromEnv() int { - if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" { - threshold, err := strconv.ParseInt(attemptThreshold, 10, 64) - if err != nil { - log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold) - return defaultAttemptThreshold - } - return int(threshold) - } - return defaultAttemptThreshold -} diff --git a/shared/relay/healthcheck/sender_test.go b/shared/relay/healthcheck/sender_test.go index 23446366a..122fe0f16 100644 --- a/shared/relay/healthcheck/sender_test.go +++ b/shared/relay/healthcheck/sender_test.go @@ -2,26 +2,23 @@ package healthcheck import ( "context" - "fmt" - "os" "testing" "time" log "github.com/sirupsen/logrus" ) -func TestMain(m *testing.M) { - // override the health check interval to speed up the test - healthCheckInterval = 2 * time.Second - healthCheckTimeout = 100 * time.Millisecond - code := m.Run() - os.Exit(code) -} +var ( + testOpts = SenderOptions{ + HealthCheckInterval: 2 * time.Second, + HealthCheckTimeout: 100 * time.Millisecond, + } +) func TestNewHealthPeriod(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) iterations := 0 @@ -32,7 +29,7 @@ func TestNewHealthPeriod(t *testing.T) { hc.OnHCResponse() case <-hc.Timeout: t.Fatalf("health check is timed out") - case <-time.After(healthCheckInterval + 
100*time.Millisecond): + case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond): t.Fatalf("health check not received") } } @@ -41,19 +38,19 @@ func TestNewHealthPeriod(t *testing.T) { func TestNewHealthFailed(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) select { case <-hc.Timeout: - case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond): + case <-time.After(testOpts.HealthCheckInterval + testOpts.HealthCheckTimeout + 100*time.Millisecond): t.Fatalf("health check is not timed out") } } func TestNewHealthcheckStop(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) time.Sleep(100 * time.Millisecond) @@ -78,7 +75,7 @@ func TestNewHealthcheckStop(t *testing.T) { func TestTimeoutReset(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) iterations := 0 @@ -89,7 +86,7 @@ func TestTimeoutReset(t *testing.T) { hc.OnHCResponse() case <-hc.Timeout: t.Fatalf("health check is timed out") - case <-time.After(healthCheckInterval + 100*time.Millisecond): + case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond): t.Fatalf("health check not received") } } @@ -118,19 +115,16 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { for _, tc := range testsCases { t.Run(tc.name, func(t *testing.T) { - originalInterval := healthCheckInterval - originalTimeout := healthCheckTimeout - healthCheckInterval = 1 * time.Second - healthCheckTimeout = 500 * time.Millisecond - - //nolint:tenv - os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) - 
defer os.Unsetenv(defaultAttemptThresholdEnv) + opts := SenderOptions{ + HealthCheckInterval: 1 * time.Second, + HealthCheckTimeout: 500 * time.Millisecond, + AttemptThreshold: tc.threshold, + } ctx, cancel := context.WithCancel(context.Background()) defer cancel() - sender := NewSender(log.WithField("test_name", tc.name)) + sender := NewSenderWithOpts(log.WithField("test_name", tc.name), opts) senderExit := make(chan struct{}) go func() { sender.StartHealthCheck(ctx) @@ -155,7 +149,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { } }() - testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval + testTimeout := (opts.HealthCheckInterval+opts.HealthCheckTimeout)*time.Duration(tc.threshold) + opts.HealthCheckInterval select { case <-sender.Timeout: @@ -175,39 +169,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { case <-time.After(2 * time.Second): t.Fatalf("sender did not exit in time") } - healthCheckInterval = originalInterval - healthCheckTimeout = originalTimeout }) } } - -//nolint:tenv -func TestGetAttemptThresholdFromEnv(t *testing.T) { - tests := []struct { - name string - envValue string - expected int - }{ - {"Default attempt threshold when env is not set", "", defaultAttemptThreshold}, - {"Custom attempt threshold when env is set to a valid integer", "3", 3}, - {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.envValue == "" { - os.Unsetenv(defaultAttemptThresholdEnv) - } else { - os.Setenv(defaultAttemptThresholdEnv, tt.envValue) - } - - result := getAttemptThresholdFromEnv() - if result != tt.expected { - t.Fatalf("Expected %d, got %d", tt.expected, result) - } - - os.Unsetenv(defaultAttemptThresholdEnv) - }) - } -} From bd23ab925e61c79808ef6b805927c922efef02dd Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 15 Sep 2025 15:08:53 +0200 Subject: [PATCH 29/41] 
[client] Fix ICE latency handling (#4501) The GetSelectedCandidatePair() does not carry the latency information. --- client/internal/peer/worker_ice.go | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 4e85ba0fa..eb886a4d3 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -218,7 +218,9 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates [] return nil, err } - if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil { + if err := agent.OnSelectedCandidatePairChange(func(c1, c2 ice.Candidate) { + w.onICESelectedCandidatePair(agent, c1, c2) + }); err != nil { return nil, err } @@ -365,26 +367,17 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) { }() } -func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) { +func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) { w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(), w.config.Key) - w.muxAgent.Lock() - - pair, err := w.agent.GetSelectedCandidatePair() - if err != nil { - w.log.Warnf("failed to get selected candidate pair: %s", err) - w.muxAgent.Unlock() + pairStat, ok := agent.GetSelectedCandidatePairStats() + if !ok { + w.log.Warnf("failed to get selected candidate pair stats") return } - if pair == nil { - w.log.Warnf("selected candidate pair is nil, cannot proceed") - w.muxAgent.Unlock() - return - } - w.muxAgent.Unlock() - duration := time.Duration(pair.CurrentRoundTripTime() * float64(time.Second)) + duration := time.Duration(pairStat.CurrentRoundTripTime * float64(time.Second)) if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil { w.log.Debugf("failed to update latency for peer: %s", err) return From 
3130cce72d59165060a6af643bc529be11da2802 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 15 Sep 2025 21:08:16 +0300 Subject: [PATCH 30/41] [management] Add rule ID validation for policy updates (#4499) --- management/server/policy.go | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/management/server/policy.go b/management/server/policy.go index 312fd53b2..3adee6397 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -167,10 +167,22 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a // validatePolicy validates the policy and its rules. func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { if policy.ID != "" { - _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) + existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) if err != nil { return err } + + // TODO: Refactor to support multiple rules per policy + existingRuleIDs := make(map[string]bool) + for _, rule := range existingPolicy.Rules { + existingRuleIDs[rule.ID] = true + } + + for _, rule := range policy.Rules { + if rule.ID != "" && !existingRuleIDs[rule.ID] { + return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) + } + } } else { policy.ID = xid.New().String() policy.AccountID = accountID From ec8d83ade442e0d212db4202fdf60bf826c6da62 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Thu, 18 Sep 2025 18:13:29 +0700 Subject: [PATCH 31/41] [client] [UI] Down & Up NetBird Async When Settings Updated [client] [UI] Down & Up NetBird Async When Settings Updated --- client/ui/client_ui.go | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 2403b5d05..25d7380a9 100644 --- a/client/ui/client_ui.go +++ 
b/client/ui/client_ui.go @@ -529,7 +529,7 @@ func (s *serviceClient) getSettingsForm() *widget.Form { var req proto.SetConfigRequest req.ProfileName = activeProf.Name req.Username = currUser.Username - + if iMngURL != "" { req.ManagementUrl = iMngURL } @@ -563,27 +563,28 @@ func (s *serviceClient) getSettingsForm() *widget.Form { return } - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) - if err != nil { - log.Errorf("get service status: %v", err) - dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings) - return - } - if status.Status == string(internal.StatusConnected) { - // run down & up - _, err = conn.Down(s.ctx, &proto.DownRequest{}) + go func() { + status, err := conn.Status(s.ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("down service: %v", err) - } - - _, err = conn.Up(s.ctx, &proto.UpRequest{}) - if err != nil { - log.Errorf("up service: %v", err) - dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings) + log.Errorf("get service status: %v", err) + dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings) return } - } + if status.Status == string(internal.StatusConnected) { + // run down & up + _, err = conn.Down(s.ctx, &proto.DownRequest{}) + if err != nil { + log.Errorf("down service: %v", err) + } + _, err = conn.Up(s.ctx, &proto.UpRequest{}) + if err != nil { + log.Errorf("up service: %v", err) + dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings) + return + } + } + }() } }, OnCancel: func() { From 2c87fa623654c5eef76bc0226062290201eef13a Mon Sep 17 00:00:00 2001 From: Diego Romar Date: Thu, 18 Sep 2025 10:07:42 -0300 Subject: [PATCH 32/41] [android] Add OnLoginSuccess callback to URLOpener interface (#4492) The callback will be fired once login -> internal.Login completes without errors --- client/android/login.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/client/android/login.go b/client/android/login.go index 
d8ac645e2..0df78dbc3 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -33,6 +33,7 @@ type ErrListener interface { // the backend want to show an url for the user type URLOpener interface { Open(string) + OnLoginSuccess() } // Auth can register or login new client @@ -181,6 +182,11 @@ func (a *Auth) login(urlOpener URLOpener) error { err = a.withBackOff(a.ctx, func() error { err := internal.Login(a.ctx, a.config, "", jwtToken) + + if err == nil { + go urlOpener.OnLoginSuccess() + } + if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { return nil } From dc30dcacce4c322502975f1f491e6774efd7e1e9 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Thu, 18 Sep 2025 19:57:07 +0300 Subject: [PATCH 33/41] [management] Filter DNS records to include only peers to connect (#4517) DNS record filtering to only include peers that a peer can connect to, reducing unnecessary DNS data in the peer's network map. - Adds a new `filterZoneRecordsForPeers` function to filter DNS records based on peer connectivity - Modifies `GetPeerNetworkMap` to use filtered DNS records instead of all records in the custom zone - Includes comprehensive test coverage for the new filtering functionality --- management/server/types/account.go | 27 +++++- management/server/types/account_test.go | 109 ++++++++++++++++++++++++ 2 files changed, 135 insertions(+), 1 deletion(-) diff --git a/management/server/types/account.go b/management/server/types/account.go index 9ac2568a0..ca075b9f6 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -302,7 +302,11 @@ func (a *Account) GetPeerNetworkMap( var zones []nbdns.CustomZone if peersCustomZone.Domain != "" { - zones = append(zones, peersCustomZone) + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) + zones = append(zones, nbdns.CustomZone{ + Domain: peersCustomZone.Domain, + Records: records, + }) } 
dnsUpdate.CustomZones = zones dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) @@ -1651,3 +1655,24 @@ func peerSupportsPortRanges(peerVer string) bool { meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer) return err == nil && meetMinVer } + +// filterZoneRecordsForPeers filters DNS records to only include peers to connect. +func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { + filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) + peerIPs := make(map[string]struct{}) + + // Add peer's own IP to include its own DNS records + peerIPs[peer.IP.String()] = struct{}{} + + for _, peerToConnect := range peersToConnect { + peerIPs[peerToConnect.IP.String()] = struct{}{} + } + + for _, record := range customZone.Records { + if _, exists := peerIPs[record.RData]; exists { + filteredRecords = append(filteredRecords, record) + } + } + + return filteredRecords +} diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index f8ab1d627..cd221b590 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -2,14 +2,17 @@ package types import ( "context" + "fmt" "net" "net/netip" "slices" "testing" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -835,3 +838,109 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) { assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 2, "expected source peers don't match") } + +func Test_FilterZoneRecordsForPeers(t 
*testing.T) { + tests := []struct { + name string + peer *nbpeer.Peer + customZone nbdns.CustomZone + peersToConnect []*nbpeer.Peer + expectedRecords []nbdns.SimpleRecord + }{ + { + name: "empty peers to connect", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + { + name: "multiple peers multiple records match", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: func() []nbdns.SimpleRecord { + var records []nbdns.SimpleRecord + for i := 1; i <= 100; i++ { + records = append(records, nbdns.SimpleRecord{ + Name: fmt.Sprintf("peer%d.netbird.cloud", i), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256), + }) + } + return records + }(), + }, + peersToConnect: func() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { + peers = append(peers, &nbpeer.Peer{ + ID: fmt.Sprintf("peer%d", i), + IP: net.ParseIP(fmt.Sprintf("10.0.%d.%d", i/256, i%256)), + }) + } + return peers + }(), + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: func() []nbdns.SimpleRecord { + var records []nbdns.SimpleRecord + for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { + records = append(records, nbdns.SimpleRecord{ + Name: fmt.Sprintf("peer%d.netbird.cloud", i), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256), + }) + } + return 
records + }(), + }, + { + name: "peers with multiple DNS labels", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer3.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"}, + {Name: "peer3-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{ + {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, + {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, + }, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: 
nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) + assert.Equal(t, len(tt.expectedRecords), len(result)) + assert.ElementsMatch(t, tt.expectedRecords, result) + }) + } +} From 90577682e45b5fbcadc7d7ae814a0dbb46a1621f Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Fri, 19 Sep 2025 14:06:44 +0300 Subject: [PATCH 34/41] Add a new product demo video (#4520) --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ea7655869..2c5ee2ab6 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ +


@@ -52,7 +53,7 @@ ### Open Source Network Security in a Single Platform -centralized-network-management 1 +https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2 ### NetBird on Lawrence Systems (Video) [![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) From 55126f990cdca39d37e9520488816a0785bf7a0b Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 20 Sep 2025 09:31:04 +0200 Subject: [PATCH 35/41] [client] Use native windows sock opts to avoid routing loops (#4314) - Move `util/grpc` and `util/net` to `client` so `internal` packages can be accessed - Add methods to return the next best interface after the NetBird interface. - Use `IP_UNICAST_IF` sock opt to force the outgoing interface for the NetBird `net.Dialer` and `net.ListenerConfig` to avoid routing loops. The interface is picked by the new route lookup method. - Some refactoring to avoid import cycles - Old behavior is available through `NB_USE_LEGACY_ROUTING=true` env var --- client/android/client.go | 2 +- client/cmd/down.go | 2 +- client/firewall/iptables/acl_linux.go | 2 +- client/firewall/iptables/router_linux.go | 2 +- client/firewall/iptables/router_linux_test.go | 2 +- client/firewall/nftables/acl_linux.go | 2 +- client/firewall/nftables/router_linux.go | 2 +- {util => client}/grpc/dialer.go | 3 +- client/iface/bind/control.go | 2 +- client/iface/bind/ice_bind.go | 2 +- client/iface/configurer/usp.go | 4 +- client/iface/device/device_kernel_unix.go | 9 +- client/iface/device/device_netstack.go | 2 +- client/iface/udpmux/mux_generic.go | 2 +- client/iface/wgproxy/ebpf/proxy.go | 2 +- client/internal/connect.go | 2 +- client/internal/dns/service_memory.go | 2 +- client/internal/dns/upstream_android.go | 2 +- client/internal/engine.go | 2 + .../internal/netflow/conntrack/conntrack.go | 2 +- client/internal/relay/relay.go | 2 +- client/internal/routemanager/manager.go | 16 +- 
.../systemops/systemops_android.go | 4 +- .../systemops/systemops_generic.go | 54 +--- .../systemops/systemops_generic_test.go | 16 +- .../routemanager/systemops/systemops_ios.go | 4 +- .../routemanager/systemops/systemops_linux.go | 12 +- .../routemanager/systemops/systemops_unix.go | 4 +- .../systemops/systemops_unix_test.go | 2 +- .../systemops/systemops_windows.go | 148 ++++++++- .../systemops/systemops_windows_test.go | 2 +- client/internal/routemanager/util/ip.go | 14 +- client/internal/stdnet/dialer.go | 2 +- client/internal/stdnet/listener.go | 2 +- client/net/conn.go | 49 +++ client/net/dial.go | 82 +++++ {util => client}/net/dial_ios.go | 0 {util => client}/net/dialer.go | 1 - client/net/dialer_dial.go | 87 ++++++ {util => client}/net/dialer_init_android.go | 0 client/net/dialer_init_generic.go | 7 + {util => client}/net/dialer_init_linux.go | 0 client/net/dialer_init_windows.go | 5 + {util => client}/net/env.go | 1 + client/net/env_android.go | 24 ++ client/net/env_generic.go | 23 ++ {util => client}/net/env_linux.go | 16 +- client/net/env_windows.go | 67 +++++ client/net/hooks/hooks.go | 93 ++++++ client/net/listen.go | 47 +++ {util => client}/net/listen_ios.go | 0 {util => client}/net/listener.go | 6 +- {util => client}/net/listener_init_android.go | 0 client/net/listener_init_generic.go | 7 + {util => client}/net/listener_init_linux.go | 0 client/net/listener_init_windows.go | 8 + client/net/listener_listen.go | 153 ++++++++++ {util => client}/net/listener_listen_ios.go | 0 {util => client}/net/net.go | 14 - {util => client}/net/net_linux.go | 0 {util => client}/net/net_test.go | 0 client/net/net_windows.go | 284 ++++++++++++++++++ {util => client}/net/protectsocket_android.go | 0 flow/client/client.go | 2 +- shared/management/client/grpc.go | 2 +- shared/relay/client/dialer/quic/quic.go | 2 +- shared/relay/client/dialer/ws/ws.go | 2 +- shared/signal/client/grpc.go | 2 +- sharedsock/sock_linux.go | 6 +- util/net/conn.go | 31 -- util/net/dial.go | 58 
---- util/net/dialer_dial.go | 107 ------- util/net/dialer_init_nonlinux.go | 7 - util/net/env_generic.go | 12 - util/net/listen.go | 37 --- util/net/listener_init_nonlinux.go | 7 - util/net/listener_listen.go | 205 ------------- 77 files changed, 1180 insertions(+), 606 deletions(-) rename {util => client}/grpc/dialer.go (98%) create mode 100644 client/net/conn.go create mode 100644 client/net/dial.go rename {util => client}/net/dial_ios.go (100%) rename {util => client}/net/dialer.go (99%) create mode 100644 client/net/dialer_dial.go rename {util => client}/net/dialer_init_android.go (100%) create mode 100644 client/net/dialer_init_generic.go rename {util => client}/net/dialer_init_linux.go (100%) create mode 100644 client/net/dialer_init_windows.go rename {util => client}/net/env.go (94%) create mode 100644 client/net/env_android.go create mode 100644 client/net/env_generic.go rename {util => client}/net/env_linux.go (86%) create mode 100644 client/net/env_windows.go create mode 100644 client/net/hooks/hooks.go create mode 100644 client/net/listen.go rename {util => client}/net/listen_ios.go (100%) rename {util => client}/net/listener.go (81%) rename {util => client}/net/listener_init_android.go (100%) create mode 100644 client/net/listener_init_generic.go rename {util => client}/net/listener_init_linux.go (100%) create mode 100644 client/net/listener_init_windows.go create mode 100644 client/net/listener_listen.go rename {util => client}/net/listener_listen_ios.go (100%) rename {util => client}/net/net.go (81%) rename {util => client}/net/net_linux.go (100%) rename {util => client}/net/net_test.go (100%) create mode 100644 client/net/net_windows.go rename {util => client}/net/protectsocket_android.go (100%) delete mode 100644 util/net/conn.go delete mode 100644 util/net/dial.go delete mode 100644 util/net/dialer_dial.go delete mode 100644 util/net/dialer_init_nonlinux.go delete mode 100644 util/net/env_generic.go delete mode 100644 util/net/listen.go delete 
mode 100644 util/net/listener_init_nonlinux.go delete mode 100644 util/net/listener_listen.go diff --git a/client/android/client.go b/client/android/client.go index 4b4fcc9be..d2d0c37f6 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -19,7 +19,7 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/util/net" + "github.com/netbirdio/netbird/client/net" ) // ConnectionListener export internal Listener for mobile diff --git a/client/cmd/down.go b/client/cmd/down.go index 3ce51c678..17c152d22 100644 --- a/client/cmd/down.go +++ b/client/cmd/down.go @@ -27,7 +27,7 @@ var downCmd = &cobra.Command{ return err } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) defer cancel() conn, err := DialClientGRPCServer(ctx, daemonAddr) diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 7b90000a8..ed8a7403b 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -12,7 +12,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 1e44c7a4d..081991235 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -19,7 +19,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet 
"github.com/netbirdio/netbird/client/net" ) // constants needed to manage and create iptable rules diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index e9eeff863..3490c5dad 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -14,7 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) func isIptablesSupported() bool { diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 52979d257..9ff5b8c92 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -16,7 +16,7 @@ import ( "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index f8fed4d80..e918d0524 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -22,7 +22,7 @@ import ( nbid "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/util/grpc/dialer.go b/client/grpc/dialer.go similarity index 98% rename from util/grpc/dialer.go rename to client/grpc/dialer.go index f6d6d2f04..7ac950d85 100644 --- a/util/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -20,8 +20,9 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + nbnet 
"github.com/netbirdio/netbird/client/net" + "github.com/netbirdio/netbird/util/embeddedroots" - nbnet "github.com/netbirdio/netbird/util/net" ) func WithCustomDialer() grpc.DialOption { diff --git a/client/iface/bind/control.go b/client/iface/bind/control.go index 89bddf12c..32b07c330 100644 --- a/client/iface/bind/control.go +++ b/client/iface/bind/control.go @@ -3,7 +3,7 @@ package bind import ( wireguard "golang.zx2c4.com/wireguard/conn" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go) diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index b74f90d6c..577c7c0c4 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -17,7 +17,7 @@ import ( "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type RecvMessage struct { diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index 945f1a162..f744e0127 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -17,8 +17,8 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/monotime" - nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -409,7 +409,7 @@ func toBytes(s string) (int64, error) { } func getFwmark() int { - if nbnet.AdvancedRouting() { + if nbnet.AdvancedRouting() && runtime.GOOS == "linux" { return nbnet.ControlPlaneMark } return 0 diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index 2ef6f6b22..cdac43a53 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ 
-15,8 +15,8 @@ import ( "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/sharedsock" - nbnet "github.com/netbirdio/netbird/util/net" ) type TunKernelDevice struct { @@ -101,13 +101,8 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { return nil, err } - var udpConn net.PacketConn = rawSock - if !nbnet.AdvancedRouting() { - udpConn = nbnet.WrapPacketConn(rawSock) - } - bindParams := udpmux.UniversalUDPMuxParams{ - UDPConn: udpConn, + UDPConn: nbnet.WrapPacketConn(rawSock), Net: t.transportNet, FilterFn: t.filterFn, WGAddress: t.address, diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index 2fcc74809..a6ef47027 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -12,7 +12,7 @@ import ( nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type TunNetstackDevice struct { diff --git a/client/iface/udpmux/mux_generic.go b/client/iface/udpmux/mux_generic.go index cf3043be0..29fc2d834 100644 --- a/client/iface/udpmux/mux_generic.go +++ b/client/iface/udpmux/mux_generic.go @@ -3,7 +3,7 @@ package udpmux import ( - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) { diff --git a/client/iface/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go index fcdc0189d..b899f1694 100644 --- a/client/iface/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -20,7 +20,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bufsize" 
"github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/client/internal/connect.go b/client/internal/connect.go index 33cd4b4a1..c9331baf5 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -34,7 +34,7 @@ import ( relayClient "github.com/netbirdio/netbird/shared/relay/client" signal "github.com/netbirdio/netbird/shared/signal/client" "github.com/netbirdio/netbird/util" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/version" ) diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 89d637686..6ef0ab526 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type ServiceViaMemory struct { diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index 6b7dcc05e..def281f28 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" "github.com/netbirdio/netbird/client/internal/peer" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type upstreamResolver struct { diff --git a/client/internal/engine.go b/client/internal/engine.go index 9dc744434..d4c465efb 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -446,6 +446,8 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return fmt.Errorf("up wg interface: %w", err) } + + // if inbound conns are blocked there is no need to create the ACL manager if 
e.firewall != nil && !e.config.BlockInbound { e.acl = acl.NewDefaultManager(e.firewall) diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index dbb4747a5..a4ffa3a25 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -14,7 +14,7 @@ import ( "github.com/ti-mo/netfilter" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const defaultChannelSize = 100 diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 8c3d5a571..fa208716f 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -12,7 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/stdnet" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // ProbeResult holds the info about the result of a relay probe request diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index a6775c45a..04513bbe4 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -36,9 +36,9 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/route" relayClient "github.com/netbirdio/netbird/shared/relay/client" - nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -108,6 +108,10 @@ func NewManager(config ManagerConfig) *DefaultManager { notifier := notifier.NewNotifier() sysOps := systemops.NewSysOps(config.WGInterface, notifier) + if runtime.GOOS == "windows" && config.WGInterface != nil { + 
nbnet.SetVPNInterfaceName(config.WGInterface.Name()) + } + dm := &DefaultManager{ ctx: mCTX, stop: cancel, @@ -208,7 +212,7 @@ func (m *DefaultManager) Init() error { return nil } - if err := m.sysOps.CleanupRouting(nil); err != nil { + if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } @@ -219,7 +223,7 @@ func (m *DefaultManager) Init() error { ips := resolveURLsToIPs(initialAddresses) - if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil { + if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil { return fmt.Errorf("setup routing: %w", err) } @@ -285,11 +289,15 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes { - if err := m.sysOps.CleanupRouting(stateManager); err != nil { + if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil { log.Errorf("Error cleaning up routing: %v", err) } else { log.Info("Routing cleanup complete") } + + if runtime.GOOS == "windows" { + nbnet.SetVPNInterfaceName("") + } } m.mux.Lock() diff --git a/client/internal/routemanager/systemops/systemops_android.go b/client/internal/routemanager/systemops/systemops_android.go index a375ce832..7cb8dae93 100644 --- a/client/internal/routemanager/systemops/systemops_android.go +++ b/client/internal/routemanager/systemops/systemops_android.go @@ -12,11 +12,11 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error { return nil } -func (r *SysOps) CleanupRouting(*statemanager.Manager) error { +func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error { return nil } diff --git a/client/internal/routemanager/systemops/systemops_generic.go 
b/client/internal/routemanager/systemops/systemops_generic.go index 128afa2a5..26a548634 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -3,7 +3,6 @@ package systemops import ( - "context" "errors" "fmt" "net" @@ -22,7 +21,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + "github.com/netbirdio/netbird/client/net/hooks" ) const localSubnetsCacheTTL = 15 * time.Minute @@ -96,9 +95,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { return nil } - // TODO: Remove hooks selectively - nbnet.RemoveDialerHooks() - nbnet.RemoveListenerHooks() + hooks.RemoveWriteHooks() + hooks.RemoveCloseHooks() + hooks.RemoveAddressRemoveHooks() if err := r.refCounter.Flush(); err != nil { return fmt.Errorf("flush route manager: %w", err) @@ -290,12 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) } func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error { - beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { - prefix, err := util.GetPrefixFromIP(ip) - if err != nil { - return fmt.Errorf("convert ip to prefix: %w", err) - } - + beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error { if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { return fmt.Errorf("adding route reference: %v", err) } @@ -304,7 +298,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M return nil } - afterHook := func(connID nbnet.ConnectionID) error { + afterHook := func(connID hooks.ConnectionID) error { if err := r.refCounter.DecrementWithID(string(connID)); err != nil { return fmt.Errorf("remove route 
reference: %w", err) } @@ -317,36 +311,20 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M var merr *multierror.Error for _, ip := range initAddresses { - if err := beforeHook("init", ip); err != nil { - merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err)) + prefix, err := util.GetPrefixFromIP(ip) + if err != nil { + merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err)) + continue + } + if err := beforeHook("init", prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err)) } } - nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { - if ctx.Err() != nil { - return ctx.Err() - } + hooks.AddWriteHook(beforeHook) + hooks.AddCloseHook(afterHook) - var merr *multierror.Error - for _, ip := range resolvedIPs { - merr = multierror.Append(merr, beforeHook(connID, ip.IP)) - } - return nberrors.FormatErrorOrNil(merr) - }) - - nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { - return afterHook(connID) - }) - - nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { - return beforeHook(connID, ip.IP) - }) - - nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { - return afterHook(connID) - }) - - nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error { + hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error { if _, err := r.refCounter.Decrement(prefix); err != nil { return fmt.Errorf("remove route reference: %w", err) } diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index c1c1182bc..32ea38a7a 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ 
b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/vars" + nbnet "github.com/netbirdio/netbird/client/net" ) type dialer interface { @@ -143,10 +144,11 @@ func TestAddVPNRoute(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) r := NewSysOps(wgInterface, nil) - err := r.SetupRouting(nil, nil) + advancedRouting := nbnet.AdvancedRouting() + err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil)) + assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) }) intf, err := net.InterfaceByName(wgInterface.Name()) @@ -341,10 +343,11 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) r := NewSysOps(wgInterface, nil) - err := r.SetupRouting(nil, nil) + advancedRouting := nbnet.AdvancedRouting() + err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil)) + assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) }) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) @@ -484,10 +487,11 @@ func setupTestEnv(t *testing.T) { }) r := NewSysOps(wgInterface, nil) - err := r.SetupRouting(nil, nil) + advancedRouting := nbnet.AdvancedRouting() + err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err, "setupRouting should not return err") t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil)) + assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) }) index, err := net.InterfaceByName(wgInterface.Name()) diff --git a/client/internal/routemanager/systemops/systemops_ios.go b/client/internal/routemanager/systemops/systemops_ios.go index 10356eae0..99a363371 100644 --- 
a/client/internal/routemanager/systemops/systemops_ios.go +++ b/client/internal/routemanager/systemops/systemops_ios.go @@ -12,14 +12,14 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error { r.mu.Lock() defer r.mu.Unlock() r.prefixes = make(map[netip.Prefix]struct{}) return nil } -func (r *SysOps) CleanupRouting(*statemanager.Manager) error { +func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error { r.mu.Lock() defer r.mu.Unlock() diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index c0cef94ba..bd10f131f 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -20,7 +20,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/sysctl" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // IPRule contains IP rule information for debugging @@ -94,15 +94,15 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. 
-func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) { - if !nbnet.AdvancedRouting() { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) (err error) { + if !advancedRouting { log.Infof("Using legacy routing setup") return r.setupRefCounter(initAddresses, stateManager) } defer func() { if err != nil { - if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { + if cleanErr := r.CleanupRouting(stateManager, advancedRouting); cleanErr != nil { log.Errorf("Error cleaning up routing: %v", cleanErr) } } @@ -132,8 +132,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. 
-func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { - if !nbnet.AdvancedRouting() { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { + if !advancedRouting { return r.cleanupRefCounter(stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index f165f7779..d43c2d5bf 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -20,11 +20,11 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { return r.cleanupRefCounter(stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_unix_test.go b/client/internal/routemanager/systemops/systemops_unix_test.go index ad37f611f..959c697e4 100644 --- a/client/internal/routemanager/systemops/systemops_unix_test.go +++ b/client/internal/routemanager/systemops/systemops_unix_test.go @@ -17,7 +17,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type PacketExpectation struct { diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 4f836897b..95645329e 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ 
-8,6 +8,7 @@ import ( "net/netip" "os" "runtime/debug" + "sort" "strconv" "sync" "syscall" @@ -19,9 +20,16 @@ import ( "golang.org/x/sys/windows" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" ) -const InfiniteLifetime = 0xffffffff +func init() { + nbnet.GetBestInterfaceFunc = GetBestInterface +} + +const ( + InfiniteLifetime = 0xffffffff +) type RouteUpdateType int @@ -77,6 +85,14 @@ type MIB_IPFORWARD_TABLE2 struct { Table [1]MIB_IPFORWARD_ROW2 // Flexible array member } +// candidateRoute represents a potential route for selection during route lookup +type candidateRoute struct { + interfaceIndex uint32 + prefixLength uint8 + routeMetric uint32 + interfaceMetric int +} + // IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix type IP_ADDRESS_PREFIX struct { Prefix SOCKADDR_INET @@ -177,11 +193,20 @@ const ( RouteDeleted ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { + if advancedRouting { + return nil + } + + log.Infof("Using legacy routing setup with ref counters") return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { + if advancedRouting { + return nil + } + return r.cleanupRefCounter(stateManager) } @@ -635,10 +660,7 @@ func getWindowsRoutingTable() (*MIB_IPFORWARD_TABLE2, error) { func freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) { if table != nil { - ret, _, _ := procFreeMibTable.Call(uintptr(unsafe.Pointer(table))) - if ret != 0 { - log.Warnf("FreeMibTable failed with return code: %d", ret) - } + _, _, _ = procFreeMibTable.Call(uintptr(unsafe.Pointer(table))) } } @@ 
-652,8 +674,7 @@ func parseWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) []DetailedRoute { entryPtr := basePtr + uintptr(i)*entrySize entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr)) - detailed := buildWindowsDetailedRoute(entry) - if detailed != nil { + if detailed := buildWindowsDetailedRoute(entry); detailed != nil { detailedRoutes = append(detailedRoutes, *detailed) } } @@ -802,6 +823,46 @@ func addZone(ip netip.Addr, interfaceIndex int) netip.Addr { return ip } +// parseCandidatesFromTable extracts all matching candidate routes from the routing table +func parseCandidatesFromTable(table *MIB_IPFORWARD_TABLE2, dest netip.Addr, skipInterfaceIndex int) []candidateRoute { + var candidates []candidateRoute + entrySize := unsafe.Sizeof(MIB_IPFORWARD_ROW2{}) + basePtr := uintptr(unsafe.Pointer(&table.Table[0])) + + for i := uint32(0); i < table.NumEntries; i++ { + entryPtr := basePtr + uintptr(i)*entrySize + entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr)) + + if candidate := parseCandidateRoute(entry, dest, skipInterfaceIndex); candidate != nil { + candidates = append(candidates, *candidate) + } + } + + return candidates +} + +// parseCandidateRoute extracts candidate route information from a MIB_IPFORWARD_ROW2 entry +// Returns nil if the route doesn't match the destination or should be skipped +func parseCandidateRoute(entry *MIB_IPFORWARD_ROW2, dest netip.Addr, skipInterfaceIndex int) *candidateRoute { + if skipInterfaceIndex > 0 && int(entry.InterfaceIndex) == skipInterfaceIndex { + return nil + } + + destPrefix := parseIPPrefix(entry.DestinationPrefix, int(entry.InterfaceIndex)) + if !destPrefix.IsValid() || !destPrefix.Contains(dest) { + return nil + } + + interfaceMetric := getInterfaceMetric(entry.InterfaceIndex, entry.DestinationPrefix.Prefix.sin6_family) + + return &candidateRoute{ + interfaceIndex: entry.InterfaceIndex, + prefixLength: entry.DestinationPrefix.PrefixLength, + routeMetric: entry.Metric, + interfaceMetric: interfaceMetric, 
+ } +} + // getInterfaceMetric retrieves the interface metric for a given interface and address family func getInterfaceMetric(interfaceIndex uint32, family int16) int { if interfaceIndex == 0 { @@ -821,6 +882,75 @@ func getInterfaceMetric(interfaceIndex uint32, family int16) int { return int(ipInterfaceRow.Metric) } +// sortRouteCandidates sorts route candidates by priority: prefix length -> route metric -> interface metric +func sortRouteCandidates(candidates []candidateRoute) { + sort.Slice(candidates, func(i, j int) bool { + if candidates[i].prefixLength != candidates[j].prefixLength { + return candidates[i].prefixLength > candidates[j].prefixLength + } + if candidates[i].routeMetric != candidates[j].routeMetric { + return candidates[i].routeMetric < candidates[j].routeMetric + } + return candidates[i].interfaceMetric < candidates[j].interfaceMetric + }) +} + +// GetBestInterface finds the best interface for reaching a destination, +// excluding the VPN interface to avoid routing loops. +// +// Route selection priority: +// 1. Longest prefix match (most specific route) +// 2. Lowest route metric +// 3. 
Lowest interface metric +func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) { + var skipInterfaceIndex int + if vpnIntf != "" { + if iface, err := net.InterfaceByName(vpnIntf); err == nil { + skipInterfaceIndex = iface.Index + } else { + return nil, fmt.Errorf("get VPN interface %s: %w", vpnIntf, err) + } + } + + table, err := getWindowsRoutingTable() + if err != nil { + return nil, fmt.Errorf("get routing table: %w", err) + } + defer freeWindowsRoutingTable(table) + + candidates := parseCandidatesFromTable(table, dest, skipInterfaceIndex) + + if len(candidates) == 0 { + return nil, fmt.Errorf("no route to %s", dest) + } + + // Sort routes: prefix length -> route metric -> interface metric + sortRouteCandidates(candidates) + + for _, candidate := range candidates { + iface, err := net.InterfaceByIndex(int(candidate.interfaceIndex)) + if err != nil { + log.Warnf("failed to get interface by index %d: %v", candidate.interfaceIndex, err) + continue + } + + if iface.Flags&net.FlagLoopback != 0 && !dest.IsLoopback() { + continue + } + + if iface.Flags&net.FlagUp == 0 { + log.Debugf("interface %s is down, trying next route", iface.Name) + continue + } + + log.Debugf("route lookup for %s: selected interface %s (index %d), route metric %d, interface metric %d", + dest, iface.Name, iface.Index, candidate.routeMetric, candidate.interfaceMetric) + return iface, nil + } + + return nil, fmt.Errorf("no usable interface found for %s", dest) +} + // formatRouteAge formats the route age in seconds to a human-readable string func formatRouteAge(ageSeconds uint32) string { if ageSeconds == 0 { diff --git a/client/internal/routemanager/systemops/systemops_windows_test.go b/client/internal/routemanager/systemops/systemops_windows_test.go index 523bd0b0d..3561adec4 100644 --- a/client/internal/routemanager/systemops/systemops_windows_test.go +++ b/client/internal/routemanager/systemops/systemops_windows_test.go @@ -15,7 +15,7 @@ import ( 
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) var ( diff --git a/client/internal/routemanager/util/ip.go b/client/internal/routemanager/util/ip.go index ac5a48e37..57ea32f69 100644 --- a/client/internal/routemanager/util/ip.go +++ b/client/internal/routemanager/util/ip.go @@ -12,18 +12,8 @@ func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) { if !ok { return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip) } + addr = addr.Unmap() - - var prefixLength int - switch { - case addr.Is4(): - prefixLength = 32 - case addr.Is6(): - prefixLength = 128 - default: - return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr) - } - - prefix := netip.PrefixFrom(addr, prefixLength) + prefix := netip.PrefixFrom(addr, addr.BitLen()) return prefix, nil } diff --git a/client/internal/stdnet/dialer.go b/client/internal/stdnet/dialer.go index e80adb42b..8961eaa69 100644 --- a/client/internal/stdnet/dialer.go +++ b/client/internal/stdnet/dialer.go @@ -5,7 +5,7 @@ import ( "github.com/pion/transport/v3" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // Dial connects to the address on the named network. diff --git a/client/internal/stdnet/listener.go b/client/internal/stdnet/listener.go index 9ce0a5556..d3be1896f 100644 --- a/client/internal/stdnet/listener.go +++ b/client/internal/stdnet/listener.go @@ -6,7 +6,7 @@ import ( "github.com/pion/transport/v3" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // ListenPacket listens for incoming packets on the given network and address. 
diff --git a/client/net/conn.go b/client/net/conn.go new file mode 100644 index 000000000..918e7f628 --- /dev/null +++ b/client/net/conn.go @@ -0,0 +1,49 @@ +//go:build !ios + +package net + +import ( + "io" + "net" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/net/hooks" +) + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID hooks.ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection. +func (c *Conn) Close() error { + return closeConn(c.ID, c.Conn) +} + +// TCPConn wraps net.TCPConn to override its Close method to include hook functionality. +type TCPConn struct { + *net.TCPConn + ID hooks.ConnectionID +} + +// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection. +func (c *TCPConn) Close() error { + return closeConn(c.ID, c.TCPConn) +} + +// closeConn is a helper function to close connections and execute close hooks. 
+func closeConn(id hooks.ConnectionID, conn io.Closer) error { + err := conn.Close() + + closeHooks := hooks.GetCloseHooks() + for _, hook := range closeHooks { + if err := hook(id); err != nil { + log.Errorf("Error executing close hook: %v", err) + } + } + + return err +} diff --git a/client/net/dial.go b/client/net/dial.go new file mode 100644 index 000000000..041a00e5d --- /dev/null +++ b/client/net/dial.go @@ -0,0 +1,82 @@ +//go:build !ios + +package net + +import ( + "fmt" + "net" + "sync" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" +) + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.DialUDP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + switch c := conn.(type) { + case *net.UDPConn: + // Advanced routing: plain connection + return c, nil + case *Conn: + // Legacy routing: wrapped connection preserves close hooks + udpConn, ok := c.Conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got %T", c.Conn) + } + return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil + } + + if err := conn.Close(); err != nil { + log.Errorf("failed to close connection: %v", err) + } + return nil, fmt.Errorf("unexpected connection type: %T", conn) +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) { + if CustomRoutingDisabled() { + return net.DialTCP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + switch c := conn.(type) { + case *net.TCPConn: + 
// Advanced routing: plain connection + return c, nil + case *Conn: + // Legacy routing: wrapped connection preserves close hooks + tcpConn, ok := c.Conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got %T", c.Conn) + } + return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil + } + + if err := conn.Close(); err != nil { + log.Errorf("failed to close connection: %v", err) + } + return nil, fmt.Errorf("unexpected connection type: %T", conn) +} diff --git a/util/net/dial_ios.go b/client/net/dial_ios.go similarity index 100% rename from util/net/dial_ios.go rename to client/net/dial_ios.go diff --git a/util/net/dialer.go b/client/net/dialer.go similarity index 99% rename from util/net/dialer.go rename to client/net/dialer.go index 0786c667e..29bec05a7 100644 --- a/util/net/dialer.go +++ b/client/net/dialer.go @@ -16,6 +16,5 @@ func NewDialer() *Dialer { Dialer: &net.Dialer{}, } dialer.init() - return dialer } diff --git a/client/net/dialer_dial.go b/client/net/dialer_dial.go new file mode 100644 index 000000000..2e1eb53d8 --- /dev/null +++ b/client/net/dialer_dial.go @@ -0,0 +1,87 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/util" + "github.com/netbirdio/netbird/client/net/hooks" +) + +// DialContext wraps the net.Dialer's DialContext method to use the custom connection +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + log.Debugf("Dialing %s %s", network, address) + + if CustomRoutingDisabled() || AdvancedRouting() { + return d.Dialer.DialContext(ctx, network, address) + } + + connID := hooks.GenerateConnID() + if err := callDialerHooks(ctx, connID, address, d.Resolver); err != nil 
{ + log.Errorf("Failed to call dialer hooks: %v", err) + } + + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) + } + + // Wrap the connection in Conn to handle Close with hooks + return &Conn{Conn: conn, ID: connID}, nil +} + +// Dial wraps the net.Dialer's Dial method to use the custom connection +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address string, customResolver *net.Resolver) error { + if ctx.Err() != nil { + return ctx.Err() + } + + writeHooks := hooks.GetWriteHooks() + if len(writeHooks) == 0 { + return nil + } + + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("split host and port: %w", err) + } + + resolver := customResolver + if resolver == nil { + resolver = net.DefaultResolver + } + + ips, err := resolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("failed to resolve address %s: %w", address, err) + } + + log.Debugf("Dialer resolved IPs for %s: %v", address, ips) + + var merr *multierror.Error + for _, ip := range ips { + prefix, err := util.GetPrefixFromIP(ip.IP) + if err != nil { + merr = multierror.Append(merr, fmt.Errorf("convert IP %s to prefix: %w", ip.IP, err)) + continue + } + for _, hook := range writeHooks { + if err := hook(connID, prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("executing dial hook for IP %s: %w", ip.IP, err)) + } + } + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/util/net/dialer_init_android.go b/client/net/dialer_init_android.go similarity index 100% rename from util/net/dialer_init_android.go rename to client/net/dialer_init_android.go diff --git a/client/net/dialer_init_generic.go b/client/net/dialer_init_generic.go new file mode 100644 index 000000000..18ebc6ad1 --- /dev/null +++ 
b/client/net/dialer_init_generic.go @@ -0,0 +1,7 @@ +//go:build !linux && !windows + +package net + +func (d *Dialer) init() { + // implemented on Linux, Android, and Windows only +} diff --git a/util/net/dialer_init_linux.go b/client/net/dialer_init_linux.go similarity index 100% rename from util/net/dialer_init_linux.go rename to client/net/dialer_init_linux.go diff --git a/client/net/dialer_init_windows.go b/client/net/dialer_init_windows.go new file mode 100644 index 000000000..6eefe5b1e --- /dev/null +++ b/client/net/dialer_init_windows.go @@ -0,0 +1,5 @@ +package net + +func (d *Dialer) init() { + d.Dialer.Control = applyUnicastIFToSocket +} diff --git a/util/net/env.go b/client/net/env.go similarity index 94% rename from util/net/env.go rename to client/net/env.go index 32425665d..8f326ca88 100644 --- a/util/net/env.go +++ b/client/net/env.go @@ -11,6 +11,7 @@ import ( const ( envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" + envUseLegacyRouting = "NB_USE_LEGACY_ROUTING" ) // CustomRoutingDisabled returns true if custom routing is disabled. diff --git a/client/net/env_android.go b/client/net/env_android.go new file mode 100644 index 000000000..9d89951a1 --- /dev/null +++ b/client/net/env_android.go @@ -0,0 +1,24 @@ +//go:build android + +package net + +// Init initializes the network environment for Android +func Init() { + // No initialization needed on Android +} + +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes. +// Always returns true on Android since we cannot handle routes dynamically. 
+func AdvancedRouting() bool { + return true +} + +// SetVPNInterfaceName is a no-op on Android +func SetVPNInterfaceName(name string) { + // No-op on Android - not needed for Android VPN service +} + +// GetVPNInterfaceName returns empty string on Android +func GetVPNInterfaceName() string { + return "" +} diff --git a/client/net/env_generic.go b/client/net/env_generic.go new file mode 100644 index 000000000..f467930c3 --- /dev/null +++ b/client/net/env_generic.go @@ -0,0 +1,23 @@ +//go:build !linux && !windows && !android + +package net + +// Init initializes the network environment (no-op on non-Linux/Windows platforms) +func Init() { + // No-op on non-Linux/Windows platforms +} + +// AdvancedRouting returns false on non-Linux/Windows platforms +func AdvancedRouting() bool { + return false +} + +// SetVPNInterfaceName is a no-op on non-Windows platforms +func SetVPNInterfaceName(name string) { + // No-op on non-Windows platforms +} + +// GetVPNInterfaceName returns empty string on non-Windows platforms +func GetVPNInterfaceName() string { + return "" +} diff --git a/util/net/env_linux.go b/client/net/env_linux.go similarity index 86% rename from util/net/env_linux.go rename to client/net/env_linux.go index 3159f6462..82d9a74a8 100644 --- a/util/net/env_linux.go +++ b/client/net/env_linux.go @@ -17,8 +17,7 @@ import ( const ( // these have the same effect, skip socket env supported for backward compatibility - envSkipSocketMark = "NB_SKIP_SOCKET_MARK" - envUseLegacyRouting = "NB_USE_LEGACY_ROUTING" + envSkipSocketMark = "NB_SKIP_SOCKET_MARK" ) var advancedRoutingSupported bool @@ -27,6 +26,7 @@ func Init() { advancedRoutingSupported = checkAdvancedRoutingSupport() } +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes func AdvancedRouting() bool { return advancedRoutingSupported } @@ -73,7 +73,7 @@ func checkAdvancedRoutingSupport() bool { } func CheckFwmarkSupport() bool { - // temporarily enable advanced routing to 
check fwmarks are supported + // temporarily enable advanced routing to check if fwmarks are supported old := advancedRoutingSupported advancedRoutingSupported = true defer func() { @@ -129,3 +129,13 @@ func CheckRuleOperationsSupport() bool { } return true } + +// SetVPNInterfaceName is a no-op on Linux +func SetVPNInterfaceName(name string) { + // No-op on Linux - not needed for fwmark-based routing +} + +// GetVPNInterfaceName returns empty string on Linux +func GetVPNInterfaceName() string { + return "" +} diff --git a/client/net/env_windows.go b/client/net/env_windows.go new file mode 100644 index 000000000..7e8868ba5 --- /dev/null +++ b/client/net/env_windows.go @@ -0,0 +1,67 @@ +//go:build windows + +package net + +import ( + "os" + "strconv" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/netstack" +) + +var ( + vpnInterfaceName string + vpnInitMutex sync.RWMutex + + advancedRoutingSupported bool +) + +func Init() { + advancedRoutingSupported = checkAdvancedRoutingSupport() +} + +func checkAdvancedRoutingSupport() bool { + var err error + var legacyRouting bool + if val := os.Getenv(envUseLegacyRouting); val != "" { + legacyRouting, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err) + } + } + + if legacyRouting || netstack.IsEnabled() { + log.Info("advanced routing has been requested to be disabled") + return false + } + + log.Info("system supports advanced routing") + + return true +} + +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes +func AdvancedRouting() bool { + return advancedRoutingSupported +} + +// GetVPNInterfaceName returns the stored VPN interface name +func GetVPNInterfaceName() string { + vpnInitMutex.RLock() + defer vpnInitMutex.RUnlock() + return vpnInterfaceName +} + +// SetVPNInterfaceName sets the VPN interface name for lazy initialization +func SetVPNInterfaceName(name string) { + 
vpnInitMutex.Lock() + defer vpnInitMutex.Unlock() + vpnInterfaceName = name + + if name != "" { + log.Infof("VPN interface name set to %s for route exclusion", name) + } +} diff --git a/client/net/hooks/hooks.go b/client/net/hooks/hooks.go new file mode 100644 index 000000000..93d8e18ef --- /dev/null +++ b/client/net/hooks/hooks.go @@ -0,0 +1,93 @@ +package hooks + +import ( + "net/netip" + "slices" + "sync" + + "github.com/google/uuid" +) + +// ConnectionID provides a globally unique identifier for network connections. +// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. +type ConnectionID string + +// GenerateConnID generates a unique identifier for each connection. +func GenerateConnID() ConnectionID { + return ConnectionID(uuid.NewString()) +} + +type WriteHookFunc func(connID ConnectionID, prefix netip.Prefix) error +type CloseHookFunc func(connID ConnectionID) error +type AddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error + +var ( + hooksMutex sync.RWMutex + + writeHooks []WriteHookFunc + closeHooks []CloseHookFunc + addressRemoveHooks []AddressRemoveHookFunc +) + +// AddWriteHook allows adding a new hook to be executed before writing/dialing. +func AddWriteHook(hook WriteHookFunc) { + hooksMutex.Lock() + defer hooksMutex.Unlock() + writeHooks = append(writeHooks, hook) +} + +// AddCloseHook allows adding a new hook to be executed on connection close. +func AddCloseHook(hook CloseHookFunc) { + hooksMutex.Lock() + defer hooksMutex.Unlock() + closeHooks = append(closeHooks, hook) +} + +// RemoveWriteHooks removes all write hooks. +func RemoveWriteHooks() { + hooksMutex.Lock() + defer hooksMutex.Unlock() + writeHooks = nil +} + +// RemoveCloseHooks removes all close hooks. +func RemoveCloseHooks() { + hooksMutex.Lock() + defer hooksMutex.Unlock() + closeHooks = nil +} + +// AddAddressRemoveHook allows adding a new hook to be executed when an address is removed. 
+func AddAddressRemoveHook(hook AddressRemoveHookFunc) { + hooksMutex.Lock() + defer hooksMutex.Unlock() + addressRemoveHooks = append(addressRemoveHooks, hook) +} + +// RemoveAddressRemoveHooks removes all listener address hooks. +func RemoveAddressRemoveHooks() { + hooksMutex.Lock() + defer hooksMutex.Unlock() + addressRemoveHooks = nil +} + +// GetWriteHooks returns a copy of the current write hooks. +func GetWriteHooks() []WriteHookFunc { + hooksMutex.RLock() + defer hooksMutex.RUnlock() + return slices.Clone(writeHooks) +} + +// GetCloseHooks returns a copy of the current close hooks. +func GetCloseHooks() []CloseHookFunc { + hooksMutex.RLock() + defer hooksMutex.RUnlock() + return slices.Clone(closeHooks) +} + +// GetAddressRemoveHooks returns a copy of the current listener address remove hooks. +func GetAddressRemoveHooks() []AddressRemoveHookFunc { + hooksMutex.RLock() + defer hooksMutex.RUnlock() + return slices.Clone(addressRemoveHooks) +} diff --git a/client/net/listen.go b/client/net/listen.go new file mode 100644 index 000000000..da7262806 --- /dev/null +++ b/client/net/listen.go @@ -0,0 +1,47 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" +) + +// ListenUDP listens on the network address and returns a transport.UDPConn +// which includes support for write and close hooks. 
+func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.ListenUDP(network, laddr) + } + + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + + switch c := conn.(type) { + case *net.UDPConn: + // Advanced routing: plain connection + return c, nil + case *PacketConn: + // Legacy routing: wrapped connection for hooks + udpConn, ok := c.PacketConn.(*net.UDPConn) + if !ok { + if err := c.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got %T", c.PacketConn) + } + return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil + } + + if err := conn.Close(); err != nil { + log.Errorf("failed to close connection: %v", err) + } + return nil, fmt.Errorf("unexpected connection type: %T", conn) +} diff --git a/util/net/listen_ios.go b/client/net/listen_ios.go similarity index 100% rename from util/net/listen_ios.go rename to client/net/listen_ios.go diff --git a/util/net/listener.go b/client/net/listener.go similarity index 81% rename from util/net/listener.go rename to client/net/listener.go index f4d769f58..4c2f53c05 100644 --- a/util/net/listener.go +++ b/client/net/listener.go @@ -7,14 +7,12 @@ import ( // ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before // responding via the socket and after closing. This can be used to bypass the VPN for listeners. type ListenerConfig struct { - *net.ListenConfig + net.ListenConfig } // NewListener creates a new ListenerConfig instance. 
func NewListener() *ListenerConfig { - listener := &ListenerConfig{ - ListenConfig: &net.ListenConfig{}, - } + listener := &ListenerConfig{} listener.init() return listener diff --git a/util/net/listener_init_android.go b/client/net/listener_init_android.go similarity index 100% rename from util/net/listener_init_android.go rename to client/net/listener_init_android.go diff --git a/client/net/listener_init_generic.go b/client/net/listener_init_generic.go new file mode 100644 index 000000000..4f8f17ab2 --- /dev/null +++ b/client/net/listener_init_generic.go @@ -0,0 +1,7 @@ +//go:build !linux && !windows + +package net + +func (l *ListenerConfig) init() { + // implemented on Linux, Android, and Windows only +} diff --git a/util/net/listener_init_linux.go b/client/net/listener_init_linux.go similarity index 100% rename from util/net/listener_init_linux.go rename to client/net/listener_init_linux.go diff --git a/client/net/listener_init_windows.go b/client/net/listener_init_windows.go new file mode 100644 index 000000000..a9399b5f1 --- /dev/null +++ b/client/net/listener_init_windows.go @@ -0,0 +1,8 @@ +package net + +func (l *ListenerConfig) init() { + // TODO: this will select a single source interface, but for UDP we can have various source interfaces and IP addresses. + // For now we stick to the one that matches the request IP address, which can be the unspecified IP. In this case + // the interface will be selected that serves the default route. 
+ l.ListenConfig.Control = applyUnicastIFToSocket +} diff --git a/client/net/listener_listen.go b/client/net/listener_listen.go new file mode 100644 index 000000000..0bb5ad67d --- /dev/null +++ b/client/net/listener_listen.go @@ -0,0 +1,153 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/util" + "github.com/netbirdio/netbird/client/net/hooks" +) + +// ListenPacket listens on the network address and returns a PacketConn +// which includes support for write hooks. +func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + if CustomRoutingDisabled() || AdvancedRouting() { + return l.ListenConfig.ListenPacket(ctx, network, address) + } + + pc, err := l.ListenConfig.ListenPacket(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("listen packet: %w", err) + } + connID := hooks.GenerateConnID() + + return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil +} + +// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. +type PacketConn struct { + net.PacketConn + ID hooks.ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil { + log.Errorf("Failed to call write hooks: %v", err) + } + return c.PacketConn.WriteTo(b, addr) +} + +// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. 
+func (c *PacketConn) Close() error { + defer c.seenAddrs.Clear() + return closeConn(c.ID, c.PacketConn) +} + +// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. +type UDPConn struct { + *net.UDPConn + ID hooks.ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil { + log.Errorf("Failed to call write hooks: %v", err) + } + return c.UDPConn.WriteTo(b, addr) +} + +// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +func (c *UDPConn) Close() error { + defer c.seenAddrs.Clear() + return closeConn(c.ID, c.UDPConn) +} + +// RemoveAddress removes an address from the seen cache and triggers removal hooks. +func (c *PacketConn) RemoveAddress(addr string) { + if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists { + return + } + + ipStr, _, err := net.SplitHostPort(addr) + if err != nil { + log.Errorf("Error splitting IP address and port: %v", err) + return + } + + ipAddr, err := netip.ParseAddr(ipStr) + if err != nil { + log.Errorf("Error parsing IP address %s: %v", ipStr, err) + return + } + + prefix := netip.PrefixFrom(ipAddr.Unmap(), ipAddr.BitLen()) + + addressRemoveHooks := hooks.GetAddressRemoveHooks() + if len(addressRemoveHooks) == 0 { + return + } + + for _, hook := range addressRemoveHooks { + if err := hook(c.ID, prefix); err != nil { + log.Errorf("Error executing listener address remove hook: %v", err) + } + } +} + +// WrapPacketConn wraps an existing net.PacketConn with nbnet hook functionality +func WrapPacketConn(conn net.PacketConn) net.PacketConn { + if AdvancedRouting() { + // hooks not required for advanced routing + return conn + } + return &PacketConn{ + PacketConn: conn, + ID: hooks.GenerateConnID(), + seenAddrs: 
&sync.Map{}, + } +} + +func callWriteHooks(id hooks.ConnectionID, seenAddrs *sync.Map, addr net.Addr) error { + if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); loaded { + return nil + } + + writeHooks := hooks.GetWriteHooks() + if len(writeHooks) == 0 { + return nil + } + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return fmt.Errorf("expected *net.UDPAddr for packet connection, got %T", addr) + } + + prefix, err := util.GetPrefixFromIP(udpAddr.IP) + if err != nil { + return fmt.Errorf("convert UDP IP %s to prefix: %w", udpAddr.IP, err) + } + + log.Debugf("Listener resolved IP for %s: %s", addr, prefix) + + var merr *multierror.Error + for _, hook := range writeHooks { + if err := hook(id, prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("execute write hook: %w", err)) + } + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/util/net/listener_listen_ios.go b/client/net/listener_listen_ios.go similarity index 100% rename from util/net/listener_listen_ios.go rename to client/net/listener_listen_ios.go diff --git a/util/net/net.go b/client/net/net.go similarity index 81% rename from util/net/net.go rename to client/net/net.go index fdcf4ee6a..a97de9d59 100644 --- a/util/net/net.go +++ b/client/net/net.go @@ -5,8 +5,6 @@ import ( "math/big" "net" "net/netip" - - "github.com/google/uuid" ) const ( @@ -44,18 +42,6 @@ func IsDataPlaneMark(fwmark uint32) bool { return fwmark >= DataPlaneMarkLower && fwmark <= DataPlaneMarkUpper } -// ConnectionID provides a globally unique identifier for network connections. -// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. -type ConnectionID string - -type AddHookFunc func(connID ConnectionID, IP net.IP) error -type RemoveHookFunc func(connID ConnectionID) error - -// GenerateConnID generates a unique identifier for each connection. 
-func GenerateConnID() ConnectionID { - return ConnectionID(uuid.NewString()) -} - func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) { var endIP net.IP addr := network.Addr().AsSlice() diff --git a/util/net/net_linux.go b/client/net/net_linux.go similarity index 100% rename from util/net/net_linux.go rename to client/net/net_linux.go diff --git a/util/net/net_test.go b/client/net/net_test.go similarity index 100% rename from util/net/net_test.go rename to client/net/net_test.go diff --git a/client/net/net_windows.go b/client/net/net_windows.go new file mode 100644 index 000000000..649d83aaf --- /dev/null +++ b/client/net/net_windows.go @@ -0,0 +1,284 @@ +package net + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "strconv" + "strings" + "syscall" + "time" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +const ( + // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options + IpUnicastIf = 31 + Ipv6UnicastIf = 31 + + // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ipv6-socket-options + Ipv6V6only = 27 +) + +// GetBestInterfaceFunc is set at runtime to avoid import cycle +var GetBestInterfaceFunc func(dest netip.Addr, vpnIntf string) (*net.Interface, error) + +// nativeToBigEndian converts a uint32 from native byte order to big-endian +func nativeToBigEndian(v uint32) uint32 { + return (v&0xff)<<24 | (v&0xff00)<<8 | (v&0xff0000)>>8 | (v&0xff000000)>>24 +} + +// parseDestinationAddress parses the destination address from various formats +func parseDestinationAddress(network, address string) (netip.Addr, error) { + if address == "" { + if strings.HasSuffix(network, "6") { + return netip.IPv6Unspecified(), nil + } + return netip.IPv4Unspecified(), nil + } + + if addrPort, err := netip.ParseAddrPort(address); err == nil { + return addrPort.Addr(), nil + } + + if dest, err := netip.ParseAddr(address); err == nil { + return dest, nil + } + + host, 
_, err := net.SplitHostPort(address) + if err != nil { + // No port, treat whole string as host + host = address + } + + if host == "" { + if strings.HasSuffix(network, "6") { + return netip.IPv6Unspecified(), nil + } + return netip.IPv4Unspecified(), nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil || len(ips) == 0 { + return netip.Addr{}, fmt.Errorf("resolve destination %s: %w", host, err) + } + + dest, ok := netip.AddrFromSlice(ips[0].IP) + if !ok { + return netip.Addr{}, fmt.Errorf("convert IP %v to netip.Addr", ips[0].IP) + } + + if ips[0].Zone != "" { + dest = dest.WithZone(ips[0].Zone) + } + + return dest, nil +} + +func getInterfaceFromZone(zone string) *net.Interface { + if zone == "" { + return nil + } + + idx, err := strconv.Atoi(zone) + if err != nil { + log.Debugf("invalid zone format for Windows (expected numeric): %s", zone) + return nil + } + + iface, err := net.InterfaceByIndex(idx) + if err != nil { + log.Debugf("failed to get interface by index %d from zone: %v", idx, err) + return nil + } + + return iface +} + +type interfaceSelection struct { + iface4 *net.Interface + iface6 *net.Interface +} + +func selectInterfaceForZone(dest netip.Addr, zone string) *interfaceSelection { + iface := getInterfaceFromZone(zone) + if iface == nil { + return nil + } + + if dest.Is6() { + return &interfaceSelection{iface6: iface} + } + return &interfaceSelection{iface4: iface} +} + +func selectInterfaceForUnspecified() (*interfaceSelection, error) { + if GetBestInterfaceFunc == nil { + return nil, errors.New("GetBestInterfaceFunc not initialized") + } + + var result interfaceSelection + vpnIfaceName := GetVPNInterfaceName() + + if iface4, err := GetBestInterfaceFunc(netip.IPv4Unspecified(), vpnIfaceName); err == nil { + result.iface4 = iface4 + } else { + log.Debugf("No IPv4 default route found: %v", err) + } + + if iface6, err := 
GetBestInterfaceFunc(netip.IPv6Unspecified(), vpnIfaceName); err == nil { + result.iface6 = iface6 + } else { + log.Debugf("No IPv6 default route found: %v", err) + } + + if result.iface4 == nil && result.iface6 == nil { + return nil, errors.New("no default routes found") + } + + return &result, nil +} + +func selectInterface(dest netip.Addr) (*interfaceSelection, error) { + if zone := dest.Zone(); zone != "" { + if selection := selectInterfaceForZone(dest, zone); selection != nil { + return selection, nil + } + } + + if dest.IsUnspecified() { + return selectInterfaceForUnspecified() + } + + if GetBestInterfaceFunc == nil { + return nil, errors.New("GetBestInterfaceFunc not initialized") + } + + iface, err := GetBestInterfaceFunc(dest, GetVPNInterfaceName()) + if err != nil { + return nil, fmt.Errorf("find route for %s: %w", dest, err) + } + + if dest.Is6() { + return &interfaceSelection{iface6: iface}, nil + } + return &interfaceSelection{iface4: iface}, nil +} + +func setIPv4UnicastIF(fd uintptr, iface *net.Interface) error { + ifaceIndexBE := nativeToBigEndian(uint32(iface.Index)) + if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IpUnicastIf, int(ifaceIndexBE)); err != nil { + return fmt.Errorf("set IP_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) + } + return nil +} + +func setIPv6UnicastIF(fd uintptr, iface *net.Interface) error { + if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6UnicastIf, iface.Index); err != nil { + return fmt.Errorf("set IPV6_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) + } + return nil +} + +func setUnicastIf(fd uintptr, network string, selection *interfaceSelection, address string) error { + // The Go runtime always passes specific network types to Control (udp4, udp6, tcp4, tcp6, etc.) 
+ // Never generic ones (udp, tcp, ip) + + switch { + case strings.HasSuffix(network, "4"): + // IPv4-only socket (udp4, tcp4, ip4) + return setUnicastIfIPv4(fd, network, selection, address) + + case strings.HasSuffix(network, "6"): + // IPv6 socket (udp6, tcp6, ip6) - could be dual-stack or IPv6-only + return setUnicastIfIPv6(fd, network, selection, address) + } + + // Shouldn't reach here based on Go's documented behavior + return fmt.Errorf("unexpected network type: %s", network) +} + +func setUnicastIfIPv4(fd uintptr, network string, selection *interfaceSelection, address string) error { + if selection.iface4 == nil { + return nil + } + + if err := setIPv4UnicastIF(fd, selection.iface4); err != nil { + return err + } + + log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s", selection.iface4.Index, selection.iface4.Name, network, address) + return nil +} + +func setUnicastIfIPv6(fd uintptr, network string, selection *interfaceSelection, address string) error { + isDualStack := checkDualStack(fd) + + // For dual-stack sockets, also set the IPv4 option + if isDualStack && selection.iface4 != nil { + if err := setIPv4UnicastIF(fd, selection.iface4); err != nil { + return err + } + log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s (dual-stack)", selection.iface4.Index, selection.iface4.Name, network, address) + } + + if selection.iface6 == nil { + return nil + } + + if err := setIPv6UnicastIF(fd, selection.iface6); err != nil { + return err + } + + log.Debugf("Set IPV6_UNICAST_IF=%d on %s for %s to %s", selection.iface6.Index, selection.iface6.Name, network, address) + return nil +} + +func checkDualStack(fd uintptr) bool { + var v6Only int + v6OnlyLen := int32(unsafe.Sizeof(v6Only)) + err := windows.Getsockopt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6V6only, (*byte)(unsafe.Pointer(&v6Only)), &v6OnlyLen) + return err == nil && v6Only == 0 +} + +// applyUnicastIFToSocket applies IpUnicastIf to a socket based on the destination address +func 
applyUnicastIFToSocket(network string, address string, c syscall.RawConn) error { + if !AdvancedRouting() { + return nil + } + + dest, err := parseDestinationAddress(network, address) + if err != nil { + return err + } + + dest = dest.Unmap() + + if !dest.IsValid() { + return fmt.Errorf("invalid destination address for %s", address) + } + + selection, err := selectInterface(dest) + if err != nil { + return err + } + + var controlErr error + err = c.Control(func(fd uintptr) { + controlErr = setUnicastIf(fd, network, selection, address) + }) + + if err != nil { + return fmt.Errorf("control: %w", err) + } + + return controlErr +} diff --git a/util/net/protectsocket_android.go b/client/net/protectsocket_android.go similarity index 100% rename from util/net/protectsocket_android.go rename to client/net/protectsocket_android.go diff --git a/flow/client/client.go b/flow/client/client.go index 949824065..603fd6882 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -20,9 +20,9 @@ import ( "google.golang.org/grpc/keepalive" "google.golang.org/grpc/status" + nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/flow/proto" "github.com/netbirdio/netbird/util/embeddedroots" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) type GRPCClient struct { diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index dc26253e9..03cc5aec3 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -17,11 +17,11 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/connectivity" + nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) const ConnectTimeout = 10 * time.Second diff --git a/shared/relay/client/dialer/quic/quic.go 
b/shared/relay/client/dialer/quic/quic.go index b496f6a9b..967e18d79 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -12,7 +12,7 @@ import ( log "github.com/sirupsen/logrus" quictls "github.com/netbirdio/netbird/shared/relay/tls" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type Dialer struct { diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index 109651f5d..ef6bd6b3c 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -16,7 +16,7 @@ import ( "github.com/netbirdio/netbird/shared/relay" "github.com/netbirdio/netbird/util/embeddedroots" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type Dialer struct { diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 82ab678f4..48d1ff04f 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -16,10 +16,10 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/client" "github.com/netbirdio/netbird/shared/signal/proto" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) // ConnStateNotifier is a wrapper interface of the status recorder diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index d4fedc492..bc2d4d1be 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -22,7 +22,7 @@ import ( "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // ErrSharedSockStopped indicates that shared socket has been stopped @@ -93,7 +93,7 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error } if err = 
nbnet.SetSocketMark(rawSock.conn4); err != nil { - return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err) + return nil, fmt.Errorf("set SO_MARK on ipv4 socket: %w", err) } var sockErr error @@ -102,7 +102,7 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error log.Errorf("Failed to create ipv6 raw socket: %v", err) } else { if err = nbnet.SetSocketMark(rawSock.conn6); err != nil { - return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err) + return nil, fmt.Errorf("set SO_MARK on ipv6 socket: %w", err) } } diff --git a/util/net/conn.go b/util/net/conn.go deleted file mode 100644 index 26693f841..000000000 --- a/util/net/conn.go +++ /dev/null @@ -1,31 +0,0 @@ -//go:build !ios - -package net - -import ( - "net" - - log "github.com/sirupsen/logrus" -) - -// Conn wraps a net.Conn to override the Close method -type Conn struct { - net.Conn - ID ConnectionID -} - -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -func (c *Conn) Close() error { - err := c.Conn.Close() - - dialerCloseHooksMutex.RLock() - defer dialerCloseHooksMutex.RUnlock() - - for _, hook := range dialerCloseHooks { - if err := hook(c.ID, &c.Conn); err != nil { - log.Errorf("Error executing dialer close hook: %v", err) - } - } - - return err -} diff --git a/util/net/dial.go b/util/net/dial.go deleted file mode 100644 index 595311492..000000000 --- a/util/net/dial.go +++ /dev/null @@ -1,58 +0,0 @@ -//go:build !ios - -package net - -import ( - "fmt" - "net" - - log "github.com/sirupsen/logrus" -) - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - if CustomRoutingDisabled() { - return net.DialUDP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := 
conn.(*Conn).Conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - if CustomRoutingDisabled() { - return net.DialTCP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) - } - - return tcpConn, nil -} diff --git a/util/net/dialer_dial.go b/util/net/dialer_dial.go deleted file mode 100644 index 1659b6220..000000000 --- a/util/net/dialer_dial.go +++ /dev/null @@ -1,107 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" -) - -type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error -type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error - -var ( - dialerDialHooksMutex sync.RWMutex - dialerDialHooks []DialerDialHookFunc - dialerCloseHooksMutex sync.RWMutex - dialerCloseHooks []DialerCloseHookFunc -) - -// AddDialerHook allows adding a new hook to be executed before dialing. -func AddDialerHook(hook DialerDialHookFunc) { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = append(dialerDialHooks, hook) -} - -// AddDialerCloseHook allows adding a new hook to be executed on connection close. 
-func AddDialerCloseHook(hook DialerCloseHookFunc) { - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = append(dialerCloseHooks, hook) -} - -// RemoveDialerHooks removes all dialer hooks. -func RemoveDialerHooks() { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = nil - - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = nil -} - -// DialContext wraps the net.Dialer's DialContext method to use the custom connection -func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - log.Debugf("Dialing %s %s", network, address) - - if CustomRoutingDisabled() { - return d.Dialer.DialContext(ctx, network, address) - } - - var resolver *net.Resolver - if d.Resolver != nil { - resolver = d.Resolver - } - - connID := GenerateConnID() - if dialerDialHooks != nil { - if err := callDialerHooks(ctx, connID, address, resolver); err != nil { - log.Errorf("Failed to call dialer hooks: %v", err) - } - } - - conn, err := d.Dialer.DialContext(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) - } - - // Wrap the connection in Conn to handle Close with hooks - return &Conn{Conn: conn, ID: connID}, nil -} - -// Dial wraps the net.Dialer's Dial method to use the custom connection -func (d *Dialer) Dial(network, address string) (net.Conn, error) { - return d.DialContext(context.Background(), network, address) -} - -func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { - host, _, err := net.SplitHostPort(address) - if err != nil { - return fmt.Errorf("split host and port: %w", err) - } - ips, err := resolver.LookupIPAddr(ctx, host) - if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) - } - - log.Debugf("Dialer resolved IPs for %s: %v", address, ips) - - var result *multierror.Error - - 
dialerDialHooksMutex.RLock() - defer dialerDialHooksMutex.RUnlock() - for _, hook := range dialerDialHooks { - if err := hook(ctx, connID, ips); err != nil { - result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) - } - } - - return result.ErrorOrNil() -} diff --git a/util/net/dialer_init_nonlinux.go b/util/net/dialer_init_nonlinux.go deleted file mode 100644 index 8c57ebbaa..000000000 --- a/util/net/dialer_init_nonlinux.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !linux - -package net - -func (d *Dialer) init() { - // implemented on Linux and Android only -} diff --git a/util/net/env_generic.go b/util/net/env_generic.go deleted file mode 100644 index 6d142a838..000000000 --- a/util/net/env_generic.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build !linux || android - -package net - -func Init() { - // nothing to do on non-linux -} - -func AdvancedRouting() bool { - // non-linux currently doesn't support advanced routing - return false -} diff --git a/util/net/listen.go b/util/net/listen.go deleted file mode 100644 index 3ae8a9435..000000000 --- a/util/net/listen.go +++ /dev/null @@ -1,37 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - "github.com/pion/transport/v3" - log "github.com/sirupsen/logrus" -) - -// ListenUDP listens on the network address and returns a transport.UDPConn -// which includes support for write and close hooks. 
-func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { - if CustomRoutingDisabled() { - return net.ListenUDP(network, laddr) - } - - conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listen UDP: %w", err) - } - - packetConn := conn.(*PacketConn) - udpConn, ok := packetConn.PacketConn.(*net.UDPConn) - if !ok { - if err := packetConn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) - } - - return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil -} diff --git a/util/net/listener_init_nonlinux.go b/util/net/listener_init_nonlinux.go deleted file mode 100644 index 80f6f7f1a..000000000 --- a/util/net/listener_init_nonlinux.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !linux - -package net - -func (l *ListenerConfig) init() { - // implemented on Linux and Android only -} diff --git a/util/net/listener_listen.go b/util/net/listener_listen.go deleted file mode 100644 index 4060ab49a..000000000 --- a/util/net/listener_listen.go +++ /dev/null @@ -1,205 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "net/netip" - "sync" - - log "github.com/sirupsen/logrus" -) - -// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. -type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error - -// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. -type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error - -// ListenerAddressRemoveHookFunc defines the function signature for hooks called when addresses are removed. 
-type ListenerAddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error - -var ( - listenerWriteHooksMutex sync.RWMutex - listenerWriteHooks []ListenerWriteHookFunc - listenerCloseHooksMutex sync.RWMutex - listenerCloseHooks []ListenerCloseHookFunc - listenerAddressRemoveHooksMutex sync.RWMutex - listenerAddressRemoveHooks []ListenerAddressRemoveHookFunc -) - -// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. -func AddListenerWriteHook(hook ListenerWriteHookFunc) { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = append(listenerWriteHooks, hook) -} - -// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. -func AddListenerCloseHook(hook ListenerCloseHookFunc) { - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = append(listenerCloseHooks, hook) -} - -// AddListenerAddressRemoveHook allows adding a new hook to be executed when an address is removed. -func AddListenerAddressRemoveHook(hook ListenerAddressRemoveHookFunc) { - listenerAddressRemoveHooksMutex.Lock() - defer listenerAddressRemoveHooksMutex.Unlock() - listenerAddressRemoveHooks = append(listenerAddressRemoveHooks, hook) -} - -// RemoveListenerHooks removes all listener hooks. -func RemoveListenerHooks() { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = nil - - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = nil - - listenerAddressRemoveHooksMutex.Lock() - defer listenerAddressRemoveHooksMutex.Unlock() - listenerAddressRemoveHooks = nil -} - -// ListenPacket listens on the network address and returns a PacketConn -// which includes support for write hooks. 
-func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { - if CustomRoutingDisabled() { - return l.ListenConfig.ListenPacket(ctx, network, address) - } - - pc, err := l.ListenConfig.ListenPacket(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("listen packet: %w", err) - } - connID := GenerateConnID() - - return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil -} - -// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. -type PacketConn struct { - net.PacketConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) - return c.PacketConn.WriteTo(b, addr) -} - -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. -func (c *PacketConn) Close() error { - c.seenAddrs = &sync.Map{} - return closeConn(c.ID, c.PacketConn) -} - -// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. -type UDPConn struct { - *net.UDPConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) - return c.UDPConn.WriteTo(b, addr) -} - -// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. -func (c *UDPConn) Close() error { - c.seenAddrs = &sync.Map{} - return closeConn(c.ID, c.UDPConn) -} - -// RemoveAddress removes an address from the seen cache and triggers removal hooks. 
-func (c *PacketConn) RemoveAddress(addr string) { - if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists { - return - } - - ipStr, _, err := net.SplitHostPort(addr) - if err != nil { - log.Errorf("Error splitting IP address and port: %v", err) - return - } - - ipAddr, err := netip.ParseAddr(ipStr) - if err != nil { - log.Errorf("Error parsing IP address %s: %v", ipStr, err) - return - } - - prefix := netip.PrefixFrom(ipAddr, ipAddr.BitLen()) - - listenerAddressRemoveHooksMutex.RLock() - defer listenerAddressRemoveHooksMutex.RUnlock() - - for _, hook := range listenerAddressRemoveHooks { - if err := hook(c.ID, prefix); err != nil { - log.Errorf("Error executing listener address remove hook: %v", err) - } - } -} - - -// WrapPacketConn wraps an existing net.PacketConn with nbnet functionality -func WrapPacketConn(conn net.PacketConn) *PacketConn { - return &PacketConn{ - PacketConn: conn, - ID: GenerateConnID(), - seenAddrs: &sync.Map{}, - } -} - -func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { - // Lookup the address in the seenAddrs map to avoid calling the hooks for every write - if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { - ipStr, _, splitErr := net.SplitHostPort(addr.String()) - if splitErr != nil { - log.Errorf("Error splitting IP address and port: %v", splitErr) - return - } - - ip, err := net.ResolveIPAddr("ip", ipStr) - if err != nil { - log.Errorf("Error resolving IP address: %v", err) - return - } - log.Debugf("Listener resolved IP for %s: %s", addr, ip) - - func() { - listenerWriteHooksMutex.RLock() - defer listenerWriteHooksMutex.RUnlock() - - for _, hook := range listenerWriteHooks { - if err := hook(id, ip, b); err != nil { - log.Errorf("Error executing listener write hook: %v", err) - } - } - }() - } -} - -func closeConn(id ConnectionID, conn net.PacketConn) error { - err := conn.Close() - - listenerCloseHooksMutex.RLock() - defer listenerCloseHooksMutex.RUnlock() - - for _, hook := 
range listenerCloseHooks { - if err := hook(id, conn); err != nil { - log.Errorf("Error executing listener close hook: %v", err) - } - } - - return err -} From ead1c618ba9a4d850fd1bef1ff7c2ddf444dc653 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sat, 20 Sep 2025 10:00:18 +0200 Subject: [PATCH 36/41] [client] Do not run up cmd if not needed in docker (#4508) optimizes the NetBird client startup process by avoiding unnecessary login commands when the peer is already authenticated. The changes increase the default login timeout and expand the log message patterns used to detect successful authentication. - Increased default login timeout from 1 to 5 seconds for more reliable authentication detection - Enhanced log pattern matching to detect both registration and ready states - Added extended regex support for more flexible pattern matching --- client/Dockerfile | 2 +- client/netbird-entrypoint.sh | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/client/Dockerfile b/client/Dockerfile index e19a09909..b2f627409 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -18,7 +18,7 @@ ENV \ NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \ - NB_ENTRYPOINT_LOGIN_TIMEOUT="1" + NB_ENTRYPOINT_LOGIN_TIMEOUT="5" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/netbird-entrypoint.sh b/client/netbird-entrypoint.sh index 2422d2683..7c9fa021a 100755 --- a/client/netbird-entrypoint.sh +++ b/client/netbird-entrypoint.sh @@ -2,7 +2,7 @@ set -eEuo pipefail : ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="5"} -: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="1"} +: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="5"} NETBIRD_BIN="${NETBIRD_BIN:-"netbird"}" export NB_LOG_FILE="${NB_LOG_FILE:-"console,/var/log/netbird/client.log"}" service_pids=() @@ -39,7 +39,7 @@ wait_for_message() { info "not waiting for log line ${message@Q} due to zero timeout." 
elif test -n "${log_file_path}"; then info "waiting for log line ${message@Q} for ${timeout} seconds..." - grep -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null) + grep -E -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null) else info "log file unsupported, sleeping for ${timeout} seconds..." sleep "${timeout}" @@ -81,7 +81,7 @@ wait_for_daemon_startup() { login_if_needed() { local timeout="${1}" - if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered'; then + if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered|management connection state READY'; then info "already logged in, skipping 'netbird up'..." else info "logging in..." From e254b4cde55e9dfe565c5b1a1dd18f53014a0a0c Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 20 Sep 2025 05:24:04 -0300 Subject: [PATCH 37/41] [misc] Update SIGN_PIPE_VER to version 0.0.23 (#4521) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7be52259b..e9741f541 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.22" + SIGN_PIPE_VER: "v0.0.23" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" From 998fb30e1eb9651f4ad32e9d50f33c0da902dff7 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sat, 20 Sep 2025 22:14:01 +0200 Subject: [PATCH 38/41] [client] Check the client status in the earlier phase (#4509) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR improves the NetBird client's status checking mechanism by implementing earlier detection of client state changes and better handling of connection lifecycle management. 
The key improvements focus on: • Enhanced status detection - Added waitForReady option to StatusRequest for improved client status handling • Better connection management - Improved context handling for signal and management gRPC connections• Reduced connection timeouts - Increased gRPC dial timeout from 3 to 10 seconds for better reliability • Cleaner error handling - Enhanced error propagation and context cancellation in retry loops Key Changes Core Status Improvements: - Added waitForReady optional field to StatusRequest proto (daemon.proto:190) - Enhanced status checking logic to detect client state changes earlier in the connection process - Improved handling of client permanent exit scenarios from retry loops Connection & Context Management: - Fixed context cancellation in management and signal client retry mechanisms - Added proper context propagation for Login operations - Enhanced gRPC connection handling with better timeout management Error Handling & Cleanup: - Moved feedback channels to upper layers for better separation of concerns - Improved error handling patterns throughout the client server implementation - Fixed synchronization issues and removed debug logging --- client/cmd/root.go | 2 +- client/cmd/up.go | 4 +- client/embed/embed.go | 2 +- client/grpc/dialer.go | 4 +- client/proto/daemon.pb.go | 22 ++++-- client/proto/daemon.proto | 2 + client/server/server.go | 126 +++++++++++++++++++++++-------- client/server/server_test.go | 16 ++-- shared/management/client/grpc.go | 2 +- shared/signal/client/grpc.go | 2 +- 10 files changed, 128 insertions(+), 54 deletions(-) diff --git a/client/cmd/root.go b/client/cmd/root.go index 5084bd38a..11e5228f1 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string { // DialClientGRPCServer returns client connection to the daemon server. 
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*3) + ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() return grpc.DialContext( diff --git a/client/cmd/up.go b/client/cmd/up.go index e686625d6..d047c041e 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager client := proto.NewDaemonServiceClient(conn) - status, err := client.Status(ctx, &proto.StatusRequest{}) + status, err := client.Status(ctx, &proto.StatusRequest{ + WaitForReady: func() *bool { b := true; return &b }(), + }) if err != nil { return fmt.Errorf("unable to get daemon status: %v", err) } diff --git a/client/embed/embed.go b/client/embed/embed.go index de83f9d96..0bfc7a37c 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -135,7 +135,7 @@ func (c *Client) Start(startCtx context.Context) error { // either startup error (permanent backoff err) or nil err (successful engine up) // TODO: make after-startup backoff err available - run := make(chan struct{}, 1) + run := make(chan struct{}) clientErr := make(chan error, 1) go func() { if err := client.Run(run); err != nil { diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 7ac950d85..69e3f088c 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -58,7 +58,7 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } -func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { +func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) if tlsEnabled { certPool, err := x509.SystemCertPool() @@ -72,7 +72,7 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { })) } - connCtx, cancel := 
context.WithTimeout(context.Background(), 30*time.Second) + connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() conn, err := grpc.DialContext( diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index c633afc83..841e3c0f7 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.6 -// protoc v5.29.3 +// protoc v6.32.1 // source: daemon.proto package proto @@ -794,8 +794,10 @@ type StatusRequest struct { state protoimpl.MessageState `protogen:"open.v1"` GetFullPeerStatus bool `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"` ShouldRunProbes bool `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // the UI do not using this yet, but CLIs could use it to wait until the status is ready + WaitForReady *bool `protobuf:"varint,3,opt,name=waitForReady,proto3,oneof" json:"waitForReady,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *StatusRequest) Reset() { @@ -842,6 +844,13 @@ func (x *StatusRequest) GetShouldRunProbes() bool { return false } +func (x *StatusRequest) GetWaitForReady() bool { + if x != nil && x.WaitForReady != nil { + return *x.WaitForReady + } + return false +} + type StatusResponse struct { state protoimpl.MessageState `protogen:"open.v1"` // status of the server. 
@@ -4673,10 +4682,12 @@ const file_daemon_proto_rawDesc = "" + "\f_profileNameB\v\n" + "\t_username\"\f\n" + "\n" + - "UpResponse\"g\n" + + "UpResponse\"\xa1\x01\n" + "\rStatusRequest\x12,\n" + "\x11getFullPeerStatus\x18\x01 \x01(\bR\x11getFullPeerStatus\x12(\n" + - "\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\"\x82\x01\n" + + "\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\x12'\n" + + "\fwaitForReady\x18\x03 \x01(\bH\x00R\fwaitForReady\x88\x01\x01B\x0f\n" + + "\r_waitForReady\"\x82\x01\n" + "\x0eStatusResponse\x12\x16\n" + "\x06status\x18\x01 \x01(\tR\x06status\x122\n" + "\n" + @@ -5231,6 +5242,7 @@ func file_daemon_proto_init() { } file_daemon_proto_msgTypes[1].OneofWrappers = []any{} file_daemon_proto_msgTypes[5].OneofWrappers = []any{} + file_daemon_proto_msgTypes[7].OneofWrappers = []any{} file_daemon_proto_msgTypes[26].OneofWrappers = []any{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 0cd3579b9..5b27b4d98 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -186,6 +186,8 @@ message UpResponse {} message StatusRequest{ bool getFullPeerStatus = 1; bool shouldRunProbes = 2; + // the UI do not using this yet, but CLIs could use it to wait until the status is ready + optional bool waitForReady = 3; } message StatusResponse{ diff --git a/client/server/server.go b/client/server/server.go index fae342f78..e6de608c5 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -67,6 +67,7 @@ type Server struct { proto.UnimplementedDaemonServiceServer clientRunning bool // protected by mutex clientRunningChan chan struct{} + clientGiveUpChan chan struct{} connectClient *internal.ConnectClient @@ -106,6 +107,10 @@ func (s *Server) Start() error { s.mutex.Lock() defer s.mutex.Unlock() + if s.clientRunning { + return nil + } + state := internal.CtxGetState(s.rootCtx) if err := handlePanicLog(); err != nil { @@ -175,12 +180,10 @@ func (s *Server) 
Start() error { return nil } - if s.clientRunning { - return nil - } s.clientRunning = true - s.clientRunningChan = make(chan struct{}, 1) - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan) + s.clientRunningChan = make(chan struct{}) + s.clientGiveUpChan = make(chan struct{}) + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) return nil } @@ -211,7 +214,7 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error { // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. -func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) { +func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) { defer func() { s.mutex.Lock() s.clientRunning = false @@ -261,6 +264,10 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil if err := backoff.Retry(runOperation, backOff); err != nil { log.Errorf("operation failed: %v", err) } + + if giveUpChan != nil { + close(giveUpChan) + } } // loginAttempt attempts to login using the provided information. 
it returns a status in case something fails @@ -379,7 +386,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro if s.actCancel != nil { s.actCancel() } - ctx, cancel := context.WithCancel(s.rootCtx) + ctx, cancel := context.WithCancel(callerCtx) md, ok := metadata.FromIncomingContext(callerCtx) if ok { @@ -389,11 +396,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.actCancel = cancel s.mutex.Unlock() - if err := restoreResidualState(ctx, s.profileManager.GetStatePath()); err != nil { + if err := restoreResidualState(s.rootCtx, s.profileManager.GetStatePath()); err != nil { log.Warnf(errRestoreResidualState, err) } - state := internal.CtxGetState(ctx) + state := internal.CtxGetState(s.rootCtx) defer func() { status, err := state.Status() if err != nil || (status != internal.StatusNeedsLogin && status != internal.StatusLoginFailed) { @@ -606,6 +613,20 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin // Up starts engine work in the daemon. func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) { s.mutex.Lock() + if s.clientRunning { + state := internal.CtxGetState(s.rootCtx) + status, err := state.Status() + if err != nil { + s.mutex.Unlock() + return nil, err + } + if status == internal.StatusNeedsLogin { + s.actCancel() + } + s.mutex.Unlock() + + return s.waitForUp(callerCtx) + } defer s.mutex.Unlock() if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil { @@ -621,16 +642,16 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR if err != nil { return nil, err } + if status != internal.StatusIdle { return nil, fmt.Errorf("up already in progress: current status %s", status) } - // it should be nil here, but . + // it should be nil here, but in case it isn't we cancel it. 
if s.actCancel != nil { s.actCancel() } ctx, cancel := context.WithCancel(s.rootCtx) - md, ok := metadata.FromIncomingContext(callerCtx) if ok { ctx = metadata.NewOutgoingContext(ctx, md) @@ -673,26 +694,31 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) + s.clientRunning = true + s.clientRunningChan = make(chan struct{}) + s.clientGiveUpChan = make(chan struct{}) + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) + + return s.waitForUp(callerCtx) +} + +// todo: handle potential race conditions +func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) { timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second) defer cancel() - if !s.clientRunning { - s.clientRunning = true - s.clientRunningChan = make(chan struct{}, 1) - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan) - } - for { - select { - case <-s.clientRunningChan: - s.isSessionActive.Store(true) - return &proto.UpResponse{}, nil - case <-callerCtx.Done(): - log.Debug("context done, stopping the wait for engine to become ready") - return nil, callerCtx.Err() - case <-timeoutCtx.Done(): - log.Debug("up is timed out, stopping the wait for engine to become ready") - return nil, timeoutCtx.Err() - } + select { + case <-s.clientGiveUpChan: + return nil, fmt.Errorf("client gave up to connect") + case <-s.clientRunningChan: + s.isSessionActive.Store(true) + return &proto.UpResponse{}, nil + case <-callerCtx.Done(): + log.Debug("context done, stopping the wait for engine to become ready") + return nil, callerCtx.Err() + case <-timeoutCtx.Done(): + log.Debug("up is timed out, stopping the wait for engine to become ready") + return nil, timeoutCtx.Err() } } @@ -966,12 +992,46 @@ func (s *Server) Status( ctx 
context.Context, msg *proto.StatusRequest, ) (*proto.StatusResponse, error) { - if ctx.Err() != nil { - return nil, ctx.Err() - } - s.mutex.Lock() - defer s.mutex.Unlock() + clientRunning := s.clientRunning + s.mutex.Unlock() + + if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning { + state := internal.CtxGetState(s.rootCtx) + status, err := state.Status() + if err != nil { + return nil, err + } + + if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting { + s.actCancel() + } + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + loop: + for { + select { + case <-s.clientGiveUpChan: + ticker.Stop() + break loop + case <-s.clientRunningChan: + ticker.Stop() + break loop + case <-ticker.C: + status, err := state.Status() + if err != nil { + continue + } + if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting { + s.actCancel() + } + continue + case <-ctx.Done(): + return nil, ctx.Err() + } + } + } status, err := internal.CtxGetState(s.rootCtx).Status() if err != nil { diff --git a/client/server/server_test.go b/client/server/server_test.go index 45a1aa5c7..755925003 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -105,7 +105,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Setenv(maxRetryTimeVar, "5s") t.Setenv(retryMultiplierVar, "1") - s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) + s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil) if counter < 3 { t.Fatalf("expected counter > 2, got %d", counter) } @@ -134,8 +134,12 @@ func TestServer_Up(t *testing.T) { profName := "default" + u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345") + require.NoError(t, err) + ic := profilemanager.ConfigInput{ - ConfigPath: filepath.Join(tempDir, profName+".json"), + ConfigPath: filepath.Join(tempDir, profName+".json"), + ManagementURL: u.String(), } _, err = 
profilemanager.UpdateOrCreateConfig(ic) @@ -153,16 +157,9 @@ func TestServer_Up(t *testing.T) { } s := New(ctx, "console", "", false, false) - err = s.Start() require.NoError(t, err) - u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345") - require.NoError(t, err) - s.config = &profilemanager.Config{ - ManagementURL: u, - } - upCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() @@ -171,6 +168,7 @@ func TestServer_Up(t *testing.T) { Username: &currUser.Username, } _, err = s.Up(upCtx, upReq) + log.Errorf("error from Up: %v", err) assert.Contains(t, err.Error(), "context deadline exceeded") } diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 03cc5aec3..f30e965be 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -52,7 +52,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) if err != nil { log.Printf("createConnection error: %v", err) return err diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 48d1ff04f..5ca0c0282 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -57,7 +57,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) if err != nil { log.Printf("createConnection error: %v", err) return err From 5853b5553c9b9e80322add7ae5076ba5cea2b74c Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 22 Sep 2025 14:32:00 +0200 Subject: [PATCH 39/41] [client] Skip interface for route lookup if it doesn't exist (#4524) --- client/internal/routemanager/systemops/systemops_windows.go | 3 
++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 95645329e..7bce6af80 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -908,7 +908,8 @@ func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) { if iface, err := net.InterfaceByName(vpnIntf); err == nil { skipInterfaceIndex = iface.Index } else { - return nil, fmt.Errorf("get VPN interface %s: %w", vpnIntf, err) + // not critical, if we cannot get ahold of the interface then we won't need to skip it + log.Warnf("failed to get VPN interface %s: %v", vpnIntf, err) } } From 58faa341d2d6dd9a3821ca9a32d7035b4f933381 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 23 Sep 2025 12:06:10 +0200 Subject: [PATCH 40/41] [management] Add logs for update channel (#4527) --- management/server/grpcserver.go | 1 + management/server/user.go | 1 + 2 files changed, 2 insertions(+) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 27d54e6c2..60a00207e 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -258,6 +258,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil { + log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) return err } diff --git a/management/server/user.go b/management/server/user.go index 3c7c3f433..d40d33c6a 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -965,6 +965,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // 
this will trigger peer disconnect from the management service + log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID) am.peersUpdateManager.CloseChannels(ctx, peerIDs) am.BufferUpdateAccountPeers(ctx, accountID) } From 644ed4b934452d578ec4854ebda311cd42341f08 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Thu, 25 Sep 2025 15:36:26 +0700 Subject: [PATCH 41/41] [client] Add WireGuard interface lifecycle monitoring (#4370) * [client] Add WireGuard interface lifecycle monitoring --- client/internal/engine.go | 23 +++++++ client/internal/wg_iface_monitor.go | 98 +++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+) create mode 100644 client/internal/wg_iface_monitor.go diff --git a/client/internal/engine.go b/client/internal/engine.go index d4c465efb..828bc6e94 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -198,6 +198,10 @@ type Engine struct { latestSyncResponse *mgmProto.SyncResponse connSemaphore *semaphoregroup.SemaphoreGroup flowManager nftypes.FlowManager + + // WireGuard interface monitor + wgIfaceMonitor *WGIfaceMonitor + wgIfaceMonitorWg sync.WaitGroup } // Peer is an instance of the Connection Peer @@ -341,6 +345,9 @@ func (e *Engine) Stop() error { log.Errorf("failed to persist state: %v", err) } + // Stop WireGuard interface monitor and wait for it to exit + e.wgIfaceMonitorWg.Wait() + return nil } @@ -479,6 +486,22 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) // starting network monitor at the very last to avoid disruptions e.startNetworkMonitor() + + // monitor WireGuard interface lifecycle and restart engine on changes + e.wgIfaceMonitor = NewWGIfaceMonitor() + e.wgIfaceMonitorWg.Add(1) + + go func() { + defer e.wgIfaceMonitorWg.Done() + + if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart { + log.Infof("WireGuard interface monitor: %s, restarting engine", err) + e.restartEngine() 
+ } else if err != nil { + log.Warnf("WireGuard interface monitor: %s", err) + } + }() + return nil } diff --git a/client/internal/wg_iface_monitor.go b/client/internal/wg_iface_monitor.go new file mode 100644 index 000000000..78d70c15b --- /dev/null +++ b/client/internal/wg_iface_monitor.go @@ -0,0 +1,98 @@ +package internal + +import ( + "context" + "errors" + "fmt" + "net" + "runtime" + "time" + + log "github.com/sirupsen/logrus" +) + +// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine +// if the interface is deleted externally while the engine is running. +type WGIfaceMonitor struct { + done chan struct{} +} + +// NewWGIfaceMonitor creates a new WGIfaceMonitor instance. +func NewWGIfaceMonitor() *WGIfaceMonitor { + return &WGIfaceMonitor{ + done: make(chan struct{}), + } +} + +// Start begins monitoring the WireGuard interface. +// It relies on the provided context cancellation to stop. +func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) { + defer close(m.done) + + // Skip on mobile platforms as they handle interface lifecycle differently + if runtime.GOOS == "android" || runtime.GOOS == "ios" { + log.Debugf("Interface monitor: skipped on %s platform", runtime.GOOS) + return false, errors.New("not supported on mobile platforms") + } + + if ifaceName == "" { + log.Debugf("Interface monitor: empty interface name, skipping monitor") + return false, errors.New("empty interface name") + } + + // Get initial interface index to track the specific interface instance + expectedIndex, err := getInterfaceIndex(ifaceName) + if err != nil { + log.Debugf("Interface monitor: interface %s not found, skipping monitor", ifaceName) + return false, fmt.Errorf("interface %s not found: %w", ifaceName, err) + } + + log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex) + + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + select { + case 
<-ctx.Done(): + log.Infof("Interface monitor: stopped for %s", ifaceName) + return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err()) + case <-ticker.C: + currentIndex, err := getInterfaceIndex(ifaceName) + if err != nil { + // Interface was deleted + log.Infof("Interface monitor: %s deleted", ifaceName) + return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err) + } + + // Check if interface index changed (interface was recreated) + if currentIndex != expectedIndex { + log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine", + ifaceName, expectedIndex, currentIndex) + return true, nil + } + } + } + +} + +// getInterfaceIndex returns the index of a network interface by name. +// Returns an error if the interface is not found. +func getInterfaceIndex(name string) (int, error) { + if name == "" { + return 0, fmt.Errorf("empty interface name") + } + ifi, err := net.InterfaceByName(name) + if err != nil { + // Check if it's specifically a "not found" error + if errors.Is(err, &net.OpError{}) { + // On some systems, this might be a "not found" error + return 0, fmt.Errorf("interface not found: %w", err) + } + return 0, fmt.Errorf("failed to lookup interface: %w", err) + } + if ifi == nil { + return 0, fmt.Errorf("interface not found") + } + return ifi.Index, nil +}