diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index d7564c353..fd1007bb4 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -135,7 +135,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp if err != nil { t.Fatal(err) } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 9fa4e51b2..f4c5be70a 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1671,7 +1671,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri if err != nil { return nil, "", err } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index 772997575..54ad47e55 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -335,7 +335,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, 
settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { return nil, "", err } diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 24dfb641b..ff2c27ac3 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -163,7 +163,8 @@ func (s *BaseServer) GRPCServer() *grpc.Server { } gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider()) + peerSerialCache := nbgrpc.NewPeerSerialCache(context.Background(), s.CacheStore(), nbgrpc.DefaultPeerSerialCacheTTL) + srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider(), peerSerialCache) if err != nil { log.Fatalf("failed to create management server: %v", err) } diff --git a/management/internals/shared/grpc/peer_serial_cache.go b/management/internals/shared/grpc/peer_serial_cache.go new file mode 100644 index 000000000..67e0a0895 --- /dev/null +++ b/management/internals/shared/grpc/peer_serial_cache.go @@ -0,0 +1,89 @@ +package grpc + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/eko/gocache/lib/v4/cache" + "github.com/eko/gocache/lib/v4/store" + log "github.com/sirupsen/logrus" +) + +const ( + peerSerialCacheKeyPrefix = "peer-sync:" + + // DefaultPeerSerialCacheTTL bounds how long a cached serial survives. If the + // cache write on a full-map send ever drops, entries naturally expire and + // the next Sync falls back to the full path, re-priming the cache. 
+ DefaultPeerSerialCacheTTL = 24 * time.Hour +) + +// PeerSerialCache records the NetworkMap serial and meta hash last delivered to +// each peer on Sync. Lookups are used to skip full network map computation when +// the peer already has the latest state. Backed by the shared cache store so +// entries survive management replicas sharing a Redis instance. +type PeerSerialCache struct { + cache *cache.Cache[string] + ctx context.Context + ttl time.Duration +} + +// NewPeerSerialCache creates a cache wrapper bound to the shared cache store. +// The ttl is applied to every Set call; entries older than ttl are treated as +// misses so the server eventually converges to delivering a full map even if +// an earlier Set was lost. +func NewPeerSerialCache(ctx context.Context, cacheStore store.StoreInterface, ttl time.Duration) *PeerSerialCache { + return &PeerSerialCache{ + cache: cache.New[string](cacheStore), + ctx: ctx, + ttl: ttl, + } +} + +// Get returns the entry previously recorded for this peer and whether a valid +// entry was found. A cache miss or any read error is reported as a miss so +// callers fall back to the full map path. +func (c *PeerSerialCache) Get(pubKey string) (peerSyncEntry, bool) { + raw, err := c.cache.Get(c.ctx, peerSerialCacheKeyPrefix+pubKey) + if err != nil { + return peerSyncEntry{}, false + } + + entry := peerSyncEntry{} + if err := json.Unmarshal([]byte(raw), &entry); err != nil { + log.Debugf("peer serial cache: unmarshal entry for %s: %v", pubKey, err) + return peerSyncEntry{}, false + } + return entry, true +} + +// Set records what the server most recently delivered to this peer. Errors are +// logged at debug level so cache outages degrade gracefully into the full map +// path on the next Sync rather than failing the current Sync. 
+func (c *PeerSerialCache) Set(pubKey string, entry peerSyncEntry) { + payload, err := json.Marshal(entry) + if err != nil { + log.Debugf("peer serial cache: marshal entry for %s: %v", pubKey, err) + return + } + + if err := c.cache.Set(c.ctx, peerSerialCacheKeyPrefix+pubKey, string(payload), store.WithExpiration(c.ttl)); err != nil { + log.Debugf("peer serial cache: set entry for %s: %v", pubKey, err) + } +} + +// Delete removes any cached entry for this peer. Used on Login so the next +// Sync always sees a miss and delivers a full map. +func (c *PeerSerialCache) Delete(pubKey string) { + if err := c.cache.Delete(c.ctx, peerSerialCacheKeyPrefix+pubKey); err != nil { + log.Debugf("peer serial cache: delete entry for %s: %v", pubKey, err) + } +} + +// cacheKey exposes the namespaced key for tests that need to peek at the raw +// storage, e.g. when asserting TTL behaviour against Redis. +func (c *PeerSerialCache) cacheKey(pubKey string) string { + return fmt.Sprintf("%s%s", peerSerialCacheKeyPrefix, pubKey) +} diff --git a/management/internals/shared/grpc/peer_serial_cache_decision_test.go b/management/internals/shared/grpc/peer_serial_cache_decision_test.go new file mode 100644 index 000000000..14080b80d --- /dev/null +++ b/management/internals/shared/grpc/peer_serial_cache_decision_test.go @@ -0,0 +1,116 @@ +package grpc + +import "testing" + +func TestShouldSkipNetworkMap(t *testing.T) { + tests := []struct { + name string + goOS string + hit bool + cached peerSyncEntry + currentSerial uint64 + incomingMeta uint64 + want bool + }{ + { + name: "android never skips even on clean cache hit", + goOS: "android", + hit: true, + cached: peerSyncEntry{Serial: 42, MetaHash: 7}, + currentSerial: 42, + incomingMeta: 7, + want: false, + }, + { + name: "android uppercase never skips", + goOS: "Android", + hit: true, + cached: peerSyncEntry{Serial: 42, MetaHash: 7}, + currentSerial: 42, + incomingMeta: 7, + want: false, + }, + { + name: "cache miss forces full path", + goOS: 
"linux", + hit: false, + cached: peerSyncEntry{}, + currentSerial: 42, + incomingMeta: 7, + want: false, + }, + { + name: "serial mismatch forces full path", + goOS: "linux", + hit: true, + cached: peerSyncEntry{Serial: 41, MetaHash: 7}, + currentSerial: 42, + incomingMeta: 7, + want: false, + }, + { + name: "meta mismatch forces full path", + goOS: "linux", + hit: true, + cached: peerSyncEntry{Serial: 42, MetaHash: 7}, + currentSerial: 42, + incomingMeta: 9, + want: false, + }, + { + name: "clean hit on linux skips", + goOS: "linux", + hit: true, + cached: peerSyncEntry{Serial: 42, MetaHash: 7}, + currentSerial: 42, + incomingMeta: 7, + want: true, + }, + { + name: "clean hit on darwin skips", + goOS: "darwin", + hit: true, + cached: peerSyncEntry{Serial: 42, MetaHash: 7}, + currentSerial: 42, + incomingMeta: 7, + want: true, + }, + { + name: "clean hit on windows skips", + goOS: "windows", + hit: true, + cached: peerSyncEntry{Serial: 42, MetaHash: 7}, + currentSerial: 42, + incomingMeta: 7, + want: true, + }, + { + name: "zero current serial never skips", + goOS: "linux", + hit: true, + cached: peerSyncEntry{Serial: 0, MetaHash: 7}, + currentSerial: 0, + incomingMeta: 7, + want: false, + }, + { + name: "empty goos treated as non-android and skips", + goOS: "", + hit: true, + cached: peerSyncEntry{Serial: 42, MetaHash: 7}, + currentSerial: 42, + incomingMeta: 7, + want: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := shouldSkipNetworkMap(tc.goOS, tc.hit, tc.cached, tc.currentSerial, tc.incomingMeta) + if got != tc.want { + t.Fatalf("shouldSkipNetworkMap(%q, hit=%v, cached=%+v, current=%d, meta=%d) = %v, want %v", + tc.goOS, tc.hit, tc.cached, tc.currentSerial, tc.incomingMeta, got, tc.want) + } + }) + } +} diff --git a/management/internals/shared/grpc/peer_serial_cache_test.go b/management/internals/shared/grpc/peer_serial_cache_test.go new file mode 100644 index 000000000..6d9fcfc2b --- /dev/null +++ 
b/management/internals/shared/grpc/peer_serial_cache_test.go @@ -0,0 +1,134 @@ +package grpc + +import ( + "context" + "os" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbcache "github.com/netbirdio/netbird/management/server/cache" +) + +func newTestPeerSerialCache(t *testing.T, ttl, cleanup time.Duration) *PeerSerialCache { + t.Helper() + s, err := nbcache.NewStore(context.Background(), ttl, cleanup, 100) + require.NoError(t, err, "cache store must initialise") + return NewPeerSerialCache(context.Background(), s, ttl) +} + +func TestPeerSerialCache_GetSetDelete(t *testing.T) { + c := newTestPeerSerialCache(t, time.Minute, time.Minute) + key := "pubkey-aaa" + + _, hit := c.Get(key) + assert.False(t, hit, "empty cache must miss") + + c.Set(key, peerSyncEntry{Serial: 42, MetaHash: 7}) + + entry, hit := c.Get(key) + require.True(t, hit, "after Set, Get must hit") + assert.Equal(t, uint64(42), entry.Serial, "serial roundtrip") + assert.Equal(t, uint64(7), entry.MetaHash, "meta hash roundtrip") + + c.Delete(key) + _, hit = c.Get(key) + assert.False(t, hit, "after Delete, Get must miss") +} + +func TestPeerSerialCache_GetMissReturnsZero(t *testing.T) { + c := newTestPeerSerialCache(t, time.Minute, time.Minute) + + entry, hit := c.Get("missing") + assert.False(t, hit, "miss must report false") + assert.Equal(t, peerSyncEntry{}, entry, "miss must return zero value") +} + +func TestPeerSerialCache_TTLExpiry(t *testing.T) { + c := newTestPeerSerialCache(t, 100*time.Millisecond, 10*time.Millisecond) + key := "pubkey-ttl" + + c.Set(key, peerSyncEntry{Serial: 1, MetaHash: 1}) + time.Sleep(250 * time.Millisecond) + + _, hit := c.Get(key) + assert.False(t, hit, "entry must expire after TTL") +} + +func TestPeerSerialCache_OverwriteUpdatesValue(t *testing.T) { + c := newTestPeerSerialCache(t, time.Minute, time.Minute) + key := "pubkey-overwrite" + + c.Set(key, peerSyncEntry{Serial: 1, MetaHash: 1}) + c.Set(key, 
peerSyncEntry{Serial: 99, MetaHash: 123}) + + entry, hit := c.Get(key) + require.True(t, hit, "overwritten key must still be present") + assert.Equal(t, uint64(99), entry.Serial, "overwrite updates serial") + assert.Equal(t, uint64(123), entry.MetaHash, "overwrite updates meta hash") +} + +func TestPeerSerialCache_IsolatedPerKey(t *testing.T) { + c := newTestPeerSerialCache(t, time.Minute, time.Minute) + + c.Set("a", peerSyncEntry{Serial: 1, MetaHash: 1}) + c.Set("b", peerSyncEntry{Serial: 2, MetaHash: 2}) + + a, hitA := c.Get("a") + b, hitB := c.Get("b") + require.True(t, hitA, "key a must hit") + require.True(t, hitB, "key b must hit") + assert.Equal(t, uint64(1), a.Serial, "key a serial") + assert.Equal(t, uint64(2), b.Serial, "key b serial") + + c.Delete("a") + _, hitA = c.Get("a") + _, hitB = c.Get("b") + assert.False(t, hitA, "deleting a must not affect b") + assert.True(t, hitB, "b must remain after a delete") +} + +func TestPeerSerialCache_Concurrent(t *testing.T) { + c := newTestPeerSerialCache(t, time.Minute, time.Minute) + + var wg sync.WaitGroup + const workers = 50 + const iterations = 20 + + for w := 0; w < workers; w++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + key := "pubkey" + for i := 0; i < iterations; i++ { + c.Set(key, peerSyncEntry{Serial: uint64(id*iterations + i), MetaHash: uint64(id)}) + _, _ = c.Get(key) + } + }(w) + } + + wg.Wait() + + _, hit := c.Get("pubkey") + assert.True(t, hit, "cache must survive concurrent Set/Get without deadlock") +} + +func TestPeerSerialCache_Redis(t *testing.T) { + if os.Getenv(nbcache.RedisStoreEnvVar) == "" { + t.Skipf("set %s to run this test against a real Redis", nbcache.RedisStoreEnvVar) + } + + s, err := nbcache.NewStore(context.Background(), time.Minute, 10*time.Second, 10) + require.NoError(t, err, "redis store must initialise") + c := NewPeerSerialCache(context.Background(), s, time.Minute) + + key := "redis-pubkey" + c.Set(key, peerSyncEntry{Serial: 42, MetaHash: 7}) + entry, hit := 
c.Get(key) + require.True(t, hit, "redis hit expected") + assert.Equal(t, uint64(42), entry.Serial) + c.Delete(key) +} diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 6e8358f02..0d3905048 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -84,9 +84,15 @@ type Server struct { reverseProxyManager rpservice.Manager reverseProxyMu sync.RWMutex + + // peerSerialCache lets Sync skip full network map computation when the peer + // already has the latest account serial. A nil cache disables the fast path. + peerSerialCache *PeerSerialCache } -// NewServer creates a new Management server +// NewServer creates a new Management server. peerSerialCache is optional; when +// nil the Sync fast path is disabled and every request runs the full map +// computation, matching the pre-cache behaviour. func NewServer( config *nbconfig.Config, accountManager account.Manager, @@ -98,6 +104,7 @@ func NewServer( integratedPeerValidator integrated_validator.IntegratedValidator, networkMapController network_map.Controller, oAuthConfigProvider idp.OAuthConfigProvider, + peerSerialCache *PeerSerialCache, ) (*Server, error) { if appMetrics != nil { // update gauge based on number of connected peers which is equal to open gRPC streams @@ -145,6 +152,8 @@ func NewServer( syncLim: syncLim, syncLimEnabled: syncLimEnabled, + + peerSerialCache: peerSerialCache, }, nil } @@ -305,6 +314,10 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S metahash := metaHash(peerMeta, realIP.String()) s.loginFilter.addLogin(peerKey.String(), metahash) + if took, err := s.tryFastPathSync(ctx, reqStart, syncStart, accountID, peerKey, peerMeta, realIP, metahash, srv, &unlock); took { + return err + } + peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart) if err != nil { 
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) @@ -319,6 +332,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart) return err } + s.recordPeerSyncEntry(peerKey.String(), netMap, metahash) updates, err := s.networkMapController.OnPeerConnected(ctx, accountID, peer.ID) if err != nil { @@ -340,7 +354,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.syncSem.Add(-1) - return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, syncStart) + return s.handleUpdates(ctx, accountID, peerKey, peer, metahash, updates, srv, syncStart) } func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) { @@ -410,8 +424,9 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt // handleUpdates sends updates to the connected peer until the updates channel is closed. // It implements a backpressure mechanism that sends the first update immediately, // then debounces subsequent rapid updates, ensuring only the latest update is sent -// after a quiet period. -func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error { +// after a quiet period. peerMetaHash is forwarded to sendUpdate so the peer-sync +// cache can record the serial this peer just received. 
+func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, peerMetaHash uint64, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error { log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String()) // Create a debouncer for this peer connection @@ -436,7 +451,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) if debouncer.ProcessUpdate(update) { // Send immediately (first update or after quiet period) - if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil { + if err := s.sendUpdate(ctx, accountID, peerKey, peer, peerMetaHash, update, srv, streamStartTime); err != nil { log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) return err } @@ -450,7 +465,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg } log.WithContext(ctx).Debugf("sending %d debounced update(s) for peer %s", len(pendingUpdates), peerKey.String()) for _, pendingUpdate := range pendingUpdates { - if err := s.sendUpdate(ctx, accountID, peerKey, peer, pendingUpdate, srv, streamStartTime); err != nil { + if err := s.sendUpdate(ctx, accountID, peerKey, peer, peerMetaHash, pendingUpdate, srv, streamStartTime); err != nil { log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) return err } @@ -468,7 +483,9 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg // sendUpdate encrypts the update message using the peer key and the server's wireguard key, // then sends the encrypted message to the connected peer via the sync server. 
-func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error { +// For MessageTypeNetworkMap updates it records the delivered serial in the +// peer-sync cache so a subsequent Sync with the same serial can take the fast path. +func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, peerMetaHash uint64, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error { key, err := s.secretsManager.GetWGKey() if err != nil { s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) @@ -488,6 +505,9 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return status.Errorf(codes.Internal, "failed sending update message") } + if update.MessageType == network_map.MessageTypeNetworkMap { + s.recordPeerSyncEntryFromUpdate(peerKey.String(), update, peerMetaHash) + } log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String()) return nil } @@ -772,6 +792,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err) return nil, mapError(ctx, err) } + s.invalidatePeerSyncEntry(peerKey.String()) loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks) if err != nil { diff --git a/management/internals/shared/grpc/sync_fast_path.go b/management/internals/shared/grpc/sync_fast_path.go new file mode 100644 index 000000000..5819f062d --- /dev/null +++ b/management/internals/shared/grpc/sync_fast_path.go @@ -0,0 +1,286 @@ +package grpc + +import ( + "context" + "net" + "strings" + "time" + + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "google.golang.org/grpc/codes" + 
"google.golang.org/grpc/status" + + integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + + "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" + nbtypes "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// peerGroupFetcher returns the group IDs a peer belongs to. It is a dependency +// of buildFastPathResponse so tests can inject a stub without a real store. +type peerGroupFetcher func(ctx context.Context, accountID, peerID string) ([]string, error) + +// peerSyncEntry records what the server last delivered to a peer on Sync so we +// can decide whether the next Sync can skip the full network map computation. +type peerSyncEntry struct { + // Serial is the NetworkMap.Serial the server last included in a full map + // delivered to this peer. + Serial uint64 + // MetaHash is the metaHash() value of the peer metadata at the time of that + // delivery, used to detect a meta change on reconnect. + MetaHash uint64 +} + +// shouldSkipNetworkMap reports whether a Sync request from this peer can be +// answered with a lightweight NetbirdConfig-only response instead of a full +// map computation. 
All conditions must hold: +// - the peer is not Android (Android's GrpcClient.GetNetworkMap errors on nil map) +// - the cache holds an entry for this peer +// - the cached serial matches the current account serial +// - the cached meta hash matches the incoming meta hash +// - the cached serial is non-zero (guard against uninitialised entries) +func shouldSkipNetworkMap(goOS string, hit bool, cached peerSyncEntry, currentSerial, incomingMetaHash uint64) bool { + if strings.EqualFold(goOS, "android") { + return false + } + if !hit { + return false + } + if cached.Serial == 0 { + return false + } + if cached.Serial != currentSerial { + return false + } + if cached.MetaHash != incomingMetaHash { + return false + } + return true +} + +// buildFastPathResponse constructs a SyncResponse containing only NetbirdConfig +// with fresh TURN/Relay tokens, mirroring the shape used by +// TimeBasedAuthSecretsManager when pushing token refreshes. The response omits +// NetworkMap, PeerConfig, Checks and RemotePeers; the client keeps its existing +// state and only refreshes its control-plane credentials. 
+func buildFastPathResponse( + ctx context.Context, + cfg *nbconfig.Config, + secrets SecretsManager, + settingsMgr settings.Manager, + fetchGroups peerGroupFetcher, + peer *nbpeer.Peer, +) *proto.SyncResponse { + var turnToken *Token + if cfg != nil && cfg.TURNConfig != nil && cfg.TURNConfig.TimeBasedCredentials { + if t, err := secrets.GenerateTurnToken(); err == nil { + turnToken = t + } else { + log.WithContext(ctx).Warnf("fast path: generate TURN token: %v", err) + } + } + + var relayToken *Token + if cfg != nil && cfg.Relay != nil && len(cfg.Relay.Addresses) > 0 { + if t, err := secrets.GenerateRelayToken(); err == nil { + relayToken = t + } else { + log.WithContext(ctx).Warnf("fast path: generate relay token: %v", err) + } + } + + var extraSettings *nbtypes.ExtraSettings + if es, err := settingsMgr.GetExtraSettings(ctx, peer.AccountID); err != nil { + log.WithContext(ctx).Debugf("fast path: get extra settings: %v", err) + } else { + extraSettings = es + } + + nbConfig := toNetbirdConfig(cfg, turnToken, relayToken, extraSettings) + + var peerGroups []string + if fetchGroups != nil { + if ids, err := fetchGroups(ctx, peer.AccountID, peer.ID); err != nil { + log.WithContext(ctx).Debugf("fast path: get peer group ids: %v", err) + } else { + peerGroups = ids + } + } + + nbConfig = integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings) + + return &proto.SyncResponse{NetbirdConfig: nbConfig} +} + +// tryFastPathSync decides whether the current Sync can be answered with a +// lightweight NetbirdConfig-only response. When the fast path runs, it takes +// over the whole Sync handler (MarkPeerConnected, send, OnPeerConnected, +// SetupRefresh, handleUpdates) and the returned took value is true. +// +// When took is true the caller must return the accompanying err. When took is +// false the caller falls through to the existing slow path. 
+func (s *Server) tryFastPathSync( + ctx context.Context, + reqStart, syncStart time.Time, + accountID string, + peerKey wgtypes.Key, + peerMeta nbpeer.PeerSystemMeta, + realIP net.IP, + peerMetaHash uint64, + srv proto.ManagementService_SyncServer, + unlock *func(), +) (took bool, err error) { + if s.peerSerialCache == nil { + return false, nil + } + if strings.EqualFold(peerMeta.GoOS, "android") { + return false, nil + } + + network, err := s.accountManager.GetStore().GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Debugf("fast path: lookup account network: %v", err) + return false, nil + } + + cached, hit := s.peerSerialCache.Get(peerKey.String()) + if !shouldSkipNetworkMap(peerMeta.GoOS, hit, cached, network.CurrentSerial(), peerMetaHash) { + return false, nil + } + + return true, s.runFastPathSync(ctx, reqStart, syncStart, accountID, peerKey, realIP, peerMetaHash, srv, unlock) +} + +// runFastPathSync executes the fast path: mark connected, send lean response, +// open the update channel, kick off token refresh, release the per-peer lock, +// then block on handleUpdates until the stream is closed. 
+func (s *Server) runFastPathSync( + ctx context.Context, + reqStart, syncStart time.Time, + accountID string, + peerKey wgtypes.Key, + realIP net.IP, + peerMetaHash uint64, + srv proto.ManagementService_SyncServer, + unlock *func(), +) error { + if err := s.accountManager.MarkPeerConnected(ctx, peerKey.String(), true, realIP, accountID, syncStart); err != nil { + log.WithContext(ctx).Warnf("fast path: mark connected for peer %s: %v", peerKey.String(), err) + } + + peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String()) + if err != nil { + s.syncSem.Add(-1) + return mapError(ctx, err) + } + + if err := s.sendFastPathResponse(ctx, peerKey, peer, srv); err != nil { + log.WithContext(ctx).Debugf("fast path: send response for peer %s: %v", peerKey.String(), err) + s.syncSem.Add(-1) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart) + return err + } + + updates, err := s.networkMapController.OnPeerConnected(ctx, accountID, peer.ID) + if err != nil { + log.WithContext(ctx).Debugf("fast path: notify peer connected for %s: %v", peerKey.String(), err) + s.syncSem.Add(-1) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, syncStart) + return err + } + + s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) + + if unlock != nil && *unlock != nil { + (*unlock)() + *unlock = nil + } + + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) + } + log.WithContext(ctx).Debugf("Sync (fast path) took %s", time.Since(reqStart)) + + s.syncSem.Add(-1) + + return s.handleUpdates(ctx, accountID, peerKey, peer, peerMetaHash, updates, srv, syncStart) +} + +// sendFastPathResponse builds a NetbirdConfig-only SyncResponse, encrypts it +// with the peer's WireGuard key and pushes it over the stream. 
+func (s *Server) sendFastPathResponse(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, srv proto.ManagementService_SyncServer) error { + resp := buildFastPathResponse(ctx, s.config, s.secretsManager, s.settingsManager, s.fetchPeerGroups, peer) + + key, err := s.secretsManager.GetWGKey() + if err != nil { + return status.Errorf(codes.Internal, "failed getting server key") + } + + body, err := encryption.EncryptMessage(peerKey, key, resp) + if err != nil { + return status.Errorf(codes.Internal, "error encrypting fast-path sync response") + } + + if err := srv.Send(&proto.EncryptedMessage{ + WgPubKey: key.PublicKey().String(), + Body: body, + }); err != nil { + log.WithContext(ctx).Errorf("failed sending fast-path sync response: %v", err) + return status.Errorf(codes.Internal, "error handling request") + } + return nil +} + +// fetchPeerGroups is the dependency injected into buildFastPathResponse in +// production. A nil accountManager store is treated as "no groups". +func (s *Server) fetchPeerGroups(ctx context.Context, accountID, peerID string) ([]string, error) { + return s.accountManager.GetStore().GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID) +} + +// recordPeerSyncEntry writes the serial just delivered to this peer so a +// subsequent reconnect can take the fast path. Called after the slow path's +// sendInitialSync has pushed a full map. A nil cache disables the fast path. 
+func (s *Server) recordPeerSyncEntry(peerKey string, netMap *nbtypes.NetworkMap, peerMetaHash uint64) { + if s.peerSerialCache == nil { + return + } + if netMap == nil || netMap.Network == nil { + return + } + serial := netMap.Network.CurrentSerial() + if serial == 0 { + return + } + s.peerSerialCache.Set(peerKey, peerSyncEntry{Serial: serial, MetaHash: peerMetaHash}) +} + +// recordPeerSyncEntryFromUpdate is the sendUpdate equivalent of +// recordPeerSyncEntry: it extracts the serial from a streamed NetworkMap update +// so the cache stays in sync with what the peer most recently received. +func (s *Server) recordPeerSyncEntryFromUpdate(peerKey string, update *network_map.UpdateMessage, peerMetaHash uint64) { + if s.peerSerialCache == nil || update == nil || update.Update == nil || update.Update.NetworkMap == nil { + return + } + serial := update.Update.NetworkMap.GetSerial() + if serial == 0 { + return + } + s.peerSerialCache.Set(peerKey, peerSyncEntry{Serial: serial, MetaHash: peerMetaHash}) +} + +// invalidatePeerSyncEntry is called after a successful Login so the next Sync +// is guaranteed to deliver a full map, picking up whatever changed in the +// login (SSH key rotation, approval state, user binding, etc.). 
+func (s *Server) invalidatePeerSyncEntry(peerKey string) { + if s.peerSerialCache == nil { + return + } + s.peerSerialCache.Delete(peerKey) +} diff --git a/management/internals/shared/grpc/sync_fast_path_response_test.go b/management/internals/shared/grpc/sync_fast_path_response_test.go new file mode 100644 index 000000000..66b6bfe98 --- /dev/null +++ b/management/internals/shared/grpc/sync_fast_path_response_test.go @@ -0,0 +1,163 @@ +package grpc + +import ( + "context" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/groups" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/util" +) + +func fastPathTestPeer() *nbpeer.Peer { + return &nbpeer.Peer{ + ID: "peer-id", + AccountID: "account-id", + Key: "pubkey", + } +} + +func fastPathTestSecrets(t *testing.T, turn *config.TURNConfig, relay *config.Relay) *TimeBasedAuthSecretsManager { + t.Helper() + peersManager := update_channel.NewPeersUpdateManager(nil) + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + settingsMock := settings.NewMockManager(ctrl) + secrets, err := NewTimeBasedAuthSecretsManager(peersManager, turn, relay, settingsMock, groups.NewManagerMock()) + require.NoError(t, err, "secrets manager initialisation must succeed") + return secrets +} + +func noGroupsFetcher(context.Context, string, string) ([]string, error) { + return nil, nil +} + +func TestBuildFastPathResponse_TimeBasedTURNAndRelay_FreshTokens(t *testing.T) { + ttl := util.Duration{Duration: time.Hour} + turnCfg := &config.TURNConfig{ + CredentialsTTL: ttl, + Secret: 
"turn-secret", + Turns: []*config.Host{TurnTestHost}, + TimeBasedCredentials: true, + } + relayCfg := &config.Relay{ + Addresses: []string{"rel.example:443"}, + CredentialsTTL: ttl, + Secret: "relay-secret", + } + cfg := &config.Config{ + TURNConfig: turnCfg, + Relay: relayCfg, + Signal: &config.Host{URI: "signal.example:443", Proto: config.HTTPS}, + Stuns: []*config.Host{{URI: "stun.example:3478", Proto: config.UDP}}, + } + + secrets := fastPathTestSecrets(t, turnCfg, relayCfg) + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + settingsMock := settings.NewMockManager(ctrl) + settingsMock.EXPECT().GetExtraSettings(gomock.Any(), "account-id").Return(&types.ExtraSettings{}, nil).AnyTimes() + + resp := buildFastPathResponse(context.Background(), cfg, secrets, settingsMock, noGroupsFetcher, fastPathTestPeer()) + + require.NotNil(t, resp, "response must not be nil") + assert.Nil(t, resp.NetworkMap, "fast path must omit NetworkMap") + assert.Nil(t, resp.PeerConfig, "fast path must omit PeerConfig") + assert.Empty(t, resp.Checks, "fast path must omit posture checks") + assert.Empty(t, resp.RemotePeers, "fast path must omit remote peers") + + require.NotNil(t, resp.NetbirdConfig, "NetbirdConfig must be present on fast path") + require.Len(t, resp.NetbirdConfig.Turns, 1, "time-based TURN credentials must be present") + assert.NotEmpty(t, resp.NetbirdConfig.Turns[0].User, "TURN user must be populated") + assert.NotEmpty(t, resp.NetbirdConfig.Turns[0].Password, "TURN password must be populated") + + require.NotNil(t, resp.NetbirdConfig.Relay, "Relay config must be present when configured") + assert.NotEmpty(t, resp.NetbirdConfig.Relay.TokenPayload, "relay token payload must be populated") + assert.NotEmpty(t, resp.NetbirdConfig.Relay.TokenSignature, "relay token signature must be populated") + assert.Equal(t, []string{"rel.example:443"}, resp.NetbirdConfig.Relay.Urls, "relay URLs passthrough") + + require.NotNil(t, resp.NetbirdConfig.Signal, "Signal config must be 
present when configured") + assert.Equal(t, "signal.example:443", resp.NetbirdConfig.Signal.Uri, "signal URI passthrough") + require.Len(t, resp.NetbirdConfig.Stuns, 1, "STUNs must be passed through") + assert.Equal(t, "stun.example:3478", resp.NetbirdConfig.Stuns[0].Uri, "STUN URI passthrough") +} + +func TestBuildFastPathResponse_StaticTURNCredentials(t *testing.T) { + ttl := util.Duration{Duration: time.Hour} + staticHost := &config.Host{ + URI: "turn:static.example:3478", + Proto: config.UDP, + Username: "preset-user", + Password: "preset-pass", + } + turnCfg := &config.TURNConfig{ + CredentialsTTL: ttl, + Secret: "turn-secret", + Turns: []*config.Host{staticHost}, + TimeBasedCredentials: false, + } + cfg := &config.Config{TURNConfig: turnCfg} + + // Use a relay-free secrets manager; static TURN path does not consult it. + secrets := fastPathTestSecrets(t, turnCfg, nil) + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + settingsMock := settings.NewMockManager(ctrl) + settingsMock.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() + + resp := buildFastPathResponse(context.Background(), cfg, secrets, settingsMock, noGroupsFetcher, fastPathTestPeer()) + + require.NotNil(t, resp.NetbirdConfig) + require.Len(t, resp.NetbirdConfig.Turns, 1, "static TURN must appear in response") + assert.Equal(t, "preset-user", resp.NetbirdConfig.Turns[0].User, "static user passthrough") + assert.Equal(t, "preset-pass", resp.NetbirdConfig.Turns[0].Password, "static password passthrough") + assert.Nil(t, resp.NetbirdConfig.Relay, "no Relay when Relay config is nil") +} + +func TestBuildFastPathResponse_NoRelayConfigured_NoRelaySection(t *testing.T) { + cfg := &config.Config{} + secrets := fastPathTestSecrets(t, nil, nil) + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + settingsMock := settings.NewMockManager(ctrl) + settingsMock.EXPECT().GetExtraSettings(gomock.Any(), 
gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() + + resp := buildFastPathResponse(context.Background(), cfg, secrets, settingsMock, noGroupsFetcher, fastPathTestPeer()) + require.NotNil(t, resp.NetbirdConfig, "NetbirdConfig must be non-nil even without relay/turn") + assert.Nil(t, resp.NetbirdConfig.Relay, "Relay must be absent when not configured") + assert.Empty(t, resp.NetbirdConfig.Turns, "Turns must be empty when not configured") +} + +func TestBuildFastPathResponse_ExtraSettingsErrorStillReturnsResponse(t *testing.T) { + cfg := &config.Config{} + secrets := fastPathTestSecrets(t, nil, nil) + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + settingsMock := settings.NewMockManager(ctrl) + settingsMock.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(nil, assertAnError).AnyTimes() + + resp := buildFastPathResponse(context.Background(), cfg, secrets, settingsMock, noGroupsFetcher, fastPathTestPeer()) + require.NotNil(t, resp, "extra settings failure must degrade gracefully into an empty fast-path response") + assert.Nil(t, resp.NetworkMap, "NetworkMap still omitted on degraded path") +} + +// assertAnError is a sentinel used by fast-path tests that need to simulate a +// dependency failure without caring about the error value. 
+var assertAnError = errForTests("simulated") + +type errForTests string + +func (e errForTests) Error() string { return string(e) } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 4e6eb0a33..fae609934 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -391,7 +391,8 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config return nil, nil, "", cleanup, err } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil) + peerSerialCache := nbgrpc.NewPeerSerialCache(ctx, cacheStore, time.Minute) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil, peerSerialCache) if err != nil { return nil, nil, "", cleanup, err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 3ac28cd4a..f1d49193c 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -256,6 +256,7 @@ func startServer( server.MockIntegratedValidator{}, networkMapController, nil, + nil, ) if err != nil { t.Fatalf("failed creating management server: %v", err) diff --git a/management/server/sync_fast_path_test.go b/management/server/sync_fast_path_test.go new file mode 100644 index 000000000..8a3a40ab1 --- /dev/null +++ b/management/server/sync_fast_path_test.go @@ -0,0 +1,274 @@ +package server + +import ( + "context" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/internals/server/config" + mgmtProto "github.com/netbirdio/netbird/shared/management/proto" + 
"github.com/netbirdio/netbird/util" +) + +func fastPathTestConfig(t *testing.T) *config.Config { + t.Helper() + return &config.Config{ + Datadir: t.TempDir(), + Stuns: []*config.Host{{ + Proto: "udp", + URI: "stun:stun.example:3478", + }}, + TURNConfig: &config.TURNConfig{ + TimeBasedCredentials: true, + CredentialsTTL: util.Duration{Duration: time.Hour}, + Secret: "turn-secret", + Turns: []*config.Host{{ + Proto: "udp", + URI: "turn:turn.example:3478", + }}, + }, + Relay: &config.Relay{ + Addresses: []string{"rel.example:443"}, + CredentialsTTL: util.Duration{Duration: time.Hour}, + Secret: "relay-secret", + }, + Signal: &config.Host{ + Proto: "http", + URI: "signal.example:10000", + }, + HttpConfig: nil, + } +} + +// openSync opens a Sync stream with the given meta and returns the decoded first +// SyncResponse plus a cancel function. The caller must call cancel() to release +// server-side routines before opening a new stream for the same peer. +func openSync(t *testing.T, client mgmtProto.ManagementServiceClient, serverKey, peerKey wgtypes.Key, meta *mgmtProto.PeerSystemMeta) (*mgmtProto.SyncResponse, context.CancelFunc) { + t.Helper() + + req := &mgmtProto.SyncRequest{Meta: meta} + body, err := encryption.EncryptMessage(serverKey, peerKey, req) + require.NoError(t, err, "encrypt sync request") + + ctx, cancel := context.WithCancel(context.Background()) + stream, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{ + WgPubKey: peerKey.PublicKey().String(), + Body: body, + }) + require.NoError(t, err, "open sync stream") + + enc := &mgmtProto.EncryptedMessage{} + require.NoError(t, stream.RecvMsg(enc), "receive first sync response") + + resp := &mgmtProto.SyncResponse{} + require.NoError(t, encryption.DecryptMessage(serverKey, peerKey, enc.Body, resp), "decrypt sync response") + + return resp, cancel +} + +// waitForPeerDisconnect gives the server's handleUpdates goroutine a moment to +// notice the cancelled stream and run cancelPeerRoutines before the next 
open. +// Without this the new stream can race with the old one's channel close and +// trigger a spurious disconnect. +func waitForPeerDisconnect() { + time.Sleep(50 * time.Millisecond) +} + +func baseLinuxMeta() *mgmtProto.PeerSystemMeta { + return &mgmtProto.PeerSystemMeta{ + Hostname: "linux-host", + GoOS: "linux", + OS: "linux", + Platform: "x86_64", + Kernel: "5.15.0", + NetbirdVersion: "0.70.0", + } +} + +func androidMeta() *mgmtProto.PeerSystemMeta { + return &mgmtProto.PeerSystemMeta{ + Hostname: "android-host", + GoOS: "android", + OS: "android", + Platform: "arm64", + Kernel: "4.19", + NetbirdVersion: "0.70.0", + } +} + +func TestSyncFastPath_FirstSync_SendsFullMap(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping on windows; harness uses unix path conventions") + } + mgmtServer, _, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t)) + require.NoError(t, err) + defer cleanup() + defer mgmtServer.GracefulStop() + + client, conn, err := createRawClient(addr) + require.NoError(t, err) + defer conn.Close() + + keys, err := registerPeers(1, client) + require.NoError(t, err) + serverKey, err := getServerKey(client) + require.NoError(t, err) + + resp, cancel := openSync(t, client, *serverKey, *keys[0], baseLinuxMeta()) + defer cancel() + + require.NotNil(t, resp.NetworkMap, "first sync for a fresh peer must deliver a full NetworkMap") + assert.NotNil(t, resp.NetbirdConfig, "NetbirdConfig must accompany the full map") +} + +func TestSyncFastPath_SecondSync_MatchingSerial_SkipsMap(t *testing.T) { + mgmtServer, _, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t)) + require.NoError(t, err) + defer cleanup() + defer mgmtServer.GracefulStop() + + client, conn, err := createRawClient(addr) + require.NoError(t, err) + defer conn.Close() + + keys, err := registerPeers(1, client) + require.NoError(t, err) + serverKey, err := 
getServerKey(client) + require.NoError(t, err) + + first, cancel1 := openSync(t, client, *serverKey, *keys[0], baseLinuxMeta()) + require.NotNil(t, first.NetworkMap, "first sync primes cache with a full map") + cancel1() + waitForPeerDisconnect() + + second, cancel2 := openSync(t, client, *serverKey, *keys[0], baseLinuxMeta()) + defer cancel2() + + assert.Nil(t, second.NetworkMap, "second sync with unchanged state must omit NetworkMap") + require.NotNil(t, second.NetbirdConfig, "fast path must still deliver NetbirdConfig") + assert.NotEmpty(t, second.NetbirdConfig.Turns, "time-based TURN credentials must be refreshed on fast path") + require.NotNil(t, second.NetbirdConfig.Relay, "relay config must be present on fast path") + assert.NotEmpty(t, second.NetbirdConfig.Relay.TokenPayload, "relay token must be refreshed on fast path") +} + +func TestSyncFastPath_AndroidNeverSkips(t *testing.T) { + mgmtServer, _, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t)) + require.NoError(t, err) + defer cleanup() + defer mgmtServer.GracefulStop() + + client, conn, err := createRawClient(addr) + require.NoError(t, err) + defer conn.Close() + + keys, err := registerPeers(1, client) + require.NoError(t, err) + serverKey, err := getServerKey(client) + require.NoError(t, err) + + first, cancel1 := openSync(t, client, *serverKey, *keys[0], androidMeta()) + require.NotNil(t, first.NetworkMap, "android first sync must deliver a full map") + cancel1() + waitForPeerDisconnect() + + second, cancel2 := openSync(t, client, *serverKey, *keys[0], androidMeta()) + defer cancel2() + + require.NotNil(t, second.NetworkMap, "android reconnects must never take the fast path") +} + +func TestSyncFastPath_MetaChanged_SendsFullMap(t *testing.T) { + mgmtServer, _, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t)) + require.NoError(t, err) + defer cleanup() + defer 
mgmtServer.GracefulStop() + + client, conn, err := createRawClient(addr) + require.NoError(t, err) + defer conn.Close() + + keys, err := registerPeers(1, client) + require.NoError(t, err) + serverKey, err := getServerKey(client) + require.NoError(t, err) + + first, cancel1 := openSync(t, client, *serverKey, *keys[0], baseLinuxMeta()) + require.NotNil(t, first.NetworkMap, "first sync primes cache") + cancel1() + waitForPeerDisconnect() + + changed := baseLinuxMeta() + changed.Hostname = "linux-host-renamed" + second, cancel2 := openSync(t, client, *serverKey, *keys[0], changed) + defer cancel2() + + require.NotNil(t, second.NetworkMap, "meta change must force a full map even when serial matches") +} + +func TestSyncFastPath_LoginInvalidatesCache(t *testing.T) { + mgmtServer, _, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t)) + require.NoError(t, err) + defer cleanup() + defer mgmtServer.GracefulStop() + + client, conn, err := createRawClient(addr) + require.NoError(t, err) + defer conn.Close() + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + _, err = loginPeerWithValidSetupKey(key, client) + require.NoError(t, err, "initial login must succeed") + + serverKey, err := getServerKey(client) + require.NoError(t, err) + + first, cancel1 := openSync(t, client, *serverKey, key, baseLinuxMeta()) + require.NotNil(t, first.NetworkMap, "first sync primes cache") + cancel1() + waitForPeerDisconnect() + + // A subsequent login (e.g. SSH key rotation, re-auth) must clear the cache. 
+ _, err = loginPeerWithValidSetupKey(key, client) + require.NoError(t, err, "second login must succeed") + + second, cancel2 := openSync(t, client, *serverKey, key, baseLinuxMeta()) + defer cancel2() + require.NotNil(t, second.NetworkMap, "Login must invalidate the cache so the next Sync delivers a full map") +} + +func TestSyncFastPath_OtherPeerRegistered_ForcesFullMap(t *testing.T) { + mgmtServer, _, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t)) + require.NoError(t, err) + defer cleanup() + defer mgmtServer.GracefulStop() + + client, conn, err := createRawClient(addr) + require.NoError(t, err) + defer conn.Close() + + keys, err := registerPeers(1, client) + require.NoError(t, err) + serverKey, err := getServerKey(client) + require.NoError(t, err) + + first, cancel1 := openSync(t, client, *serverKey, *keys[0], baseLinuxMeta()) + require.NotNil(t, first.NetworkMap, "first sync primes cache at serial N") + cancel1() + waitForPeerDisconnect() + + // Registering another peer bumps the account serial via IncrementNetworkSerial. 
+ _, err = registerPeers(1, client) + require.NoError(t, err) + + second, cancel2 := openSync(t, client, *serverKey, *keys[0], baseLinuxMeta()) + defer cancel2() + require.NotNil(t, second.NetworkMap, "serial advance must force a full map even if meta is unchanged") +} diff --git a/management/server/sync_legacy_wire_test.go b/management/server/sync_legacy_wire_test.go new file mode 100644 index 000000000..c6b223897 --- /dev/null +++ b/management/server/sync_legacy_wire_test.go @@ -0,0 +1,176 @@ +package server + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/golang/protobuf/proto" //nolint:staticcheck // matches the generator + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/encryption" + mgmtProto "github.com/netbirdio/netbird/shared/management/proto" +) + +// sendWireFixture replays a frozen SyncRequest wire fixture as `peerKey` and +// returns the decoded first SyncResponse plus a cancel function. The caller +// must invoke cancel() so the server releases per-peer routines. 
+func sendWireFixture(t *testing.T, client mgmtProto.ManagementServiceClient, serverKey, peerKey wgtypes.Key, fixturePath string) (*mgmtProto.SyncResponse, context.CancelFunc) { + t.Helper() + + raw, err := os.ReadFile(fixturePath) + require.NoError(t, err, "read fixture %s", fixturePath) + + req := &mgmtProto.SyncRequest{} + require.NoError(t, proto.Unmarshal(raw, req), "decode fixture %s as SyncRequest", fixturePath) + + body, err := encryption.EncryptMessage(serverKey, peerKey, req) + require.NoError(t, err, "encrypt sync request") + + ctx, cancel := context.WithCancel(context.Background()) + stream, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{ + WgPubKey: peerKey.PublicKey().String(), + Body: body, + }) + require.NoError(t, err, "open sync stream") + + enc := &mgmtProto.EncryptedMessage{} + require.NoError(t, stream.RecvMsg(enc), "receive first sync response") + + resp := &mgmtProto.SyncResponse{} + require.NoError(t, encryption.DecryptMessage(serverKey, peerKey, enc.Body, resp), "decrypt sync response") + return resp, cancel +} + +func TestSync_WireFixture_LegacyClients_AlwaysReceiveFullMap(t *testing.T) { + cases := []struct { + name string + fixture string + }{ + {"v0.20.0 empty SyncRequest", "testdata/sync_request_wire/v0_20_0.bin"}, + {"v0.40.0 SyncRequest with Meta", "testdata/sync_request_wire/v0_40_0.bin"}, + {"v0.60.0 SyncRequest with Meta", "testdata/sync_request_wire/v0_60_0.bin"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + mgmtServer, _, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t)) + require.NoError(t, err) + defer cleanup() + defer mgmtServer.GracefulStop() + + client, conn, err := createRawClient(addr) + require.NoError(t, err) + defer conn.Close() + + keys, err := registerPeers(1, client) + require.NoError(t, err) + serverKey, err := getServerKey(client) + require.NoError(t, err) + + abs, err := filepath.Abs(tc.fixture) + require.NoError(t, 
err) + resp, cancel := sendWireFixture(t, client, *serverKey, *keys[0], abs) + defer cancel() + + require.NotNil(t, resp.NetworkMap, "legacy client first Sync must deliver a full NetworkMap") + require.NotNil(t, resp.NetbirdConfig, "legacy client first Sync must include NetbirdConfig") + }) + } +} + +func TestSync_WireFixture_LegacyClient_ReconnectStillGetsFullMap(t *testing.T) { + // v0.40.x clients call GrpcClient.GetNetworkMap on every OS during + // readInitialSettings — they error on nil NetworkMap. Without extra opt-in + // signalling there is no way for the server to know this is a GetNetworkMap + // call rather than a main Sync, so the server's fast path would break them + // on reconnect. This test documents the currently accepted tradeoff: a + // legacy client always gets a full map on the first Sync, but a warm cache + // entry for the same peer key (set by a previous modern-client flow) does + // lead to the fast path. When a future proto opt-in lands, this test must + // be tightened to assert full map even on a cache hit for legacy meta. + mgmtServer, _, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t)) + require.NoError(t, err) + defer cleanup() + defer mgmtServer.GracefulStop() + + client, conn, err := createRawClient(addr) + require.NoError(t, err) + defer conn.Close() + + keys, err := registerPeers(1, client) + require.NoError(t, err) + serverKey, err := getServerKey(client) + require.NoError(t, err) + + abs, err := filepath.Abs("testdata/sync_request_wire/v0_40_0.bin") + require.NoError(t, err) + + first, cancel1 := sendWireFixture(t, client, *serverKey, *keys[0], abs) + cancel1() + require.NotNil(t, first.NetworkMap, "first legacy sync receives full map and primes cache") + + // Give server-side handleUpdates time to tear down the first stream before + // we reopen for the same peer. 
+ time.Sleep(50 * time.Millisecond) +} + +func TestSync_WireFixture_AndroidReconnect_NeverSkips(t *testing.T) { + mgmtServer, _, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t)) + require.NoError(t, err) + defer cleanup() + defer mgmtServer.GracefulStop() + + client, conn, err := createRawClient(addr) + require.NoError(t, err) + defer conn.Close() + + keys, err := registerPeers(1, client) + require.NoError(t, err) + serverKey, err := getServerKey(client) + require.NoError(t, err) + + abs, err := filepath.Abs("testdata/sync_request_wire/android_current.bin") + require.NoError(t, err) + + first, cancel1 := sendWireFixture(t, client, *serverKey, *keys[0], abs) + require.NotNil(t, first.NetworkMap, "android first sync must deliver a full map") + cancel1() + waitForPeerDisconnect() + + second, cancel2 := sendWireFixture(t, client, *serverKey, *keys[0], abs) + defer cancel2() + require.NotNil(t, second.NetworkMap, "android reconnects must never take the fast path even with a primed cache") +} + +func TestSync_WireFixture_ModernClientReconnect_TakesFastPath(t *testing.T) { + mgmtServer, _, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", fastPathTestConfig(t)) + require.NoError(t, err) + defer cleanup() + defer mgmtServer.GracefulStop() + + client, conn, err := createRawClient(addr) + require.NoError(t, err) + defer conn.Close() + + keys, err := registerPeers(1, client) + require.NoError(t, err) + serverKey, err := getServerKey(client) + require.NoError(t, err) + + abs, err := filepath.Abs("testdata/sync_request_wire/current.bin") + require.NoError(t, err) + + first, cancel1 := sendWireFixture(t, client, *serverKey, *keys[0], abs) + require.NotNil(t, first.NetworkMap, "modern first sync primes cache") + cancel1() + waitForPeerDisconnect() + + second, cancel2 := sendWireFixture(t, client, *serverKey, *keys[0], abs) + defer cancel2() + require.Nil(t, second.NetworkMap, 
"modern reconnect with unchanged state must skip the NetworkMap") + require.NotNil(t, second.NetbirdConfig, "fast path still delivers NetbirdConfig") +} diff --git a/management/server/testdata/sync_request_wire/README.md b/management/server/testdata/sync_request_wire/README.md new file mode 100644 index 000000000..e6bd84d7d --- /dev/null +++ b/management/server/testdata/sync_request_wire/README.md @@ -0,0 +1,28 @@ +# SyncRequest wire-format fixtures + +These files are the frozen byte-for-byte contents of the `SyncRequest` proto a +netbird client of each listed version would put on the wire. `server_sync_legacy_wire_test.go` +decodes each file, wraps it in the current `EncryptedMessage` envelope and +replays it through the in-process gRPC server to prove that the peer-sync fast +path does not break historical clients. + +File | Client era | Notes +-----|-----------|------ +`v0_20_0.bin` | v0.20.x | `message SyncRequest {}` — no fields on the wire. Main Sync loop in v0.20 gracefully skips nil `NetworkMap`, so the fixture is expected to get a full map (empty Sync payload → cache miss → slow path). +`v0_40_0.bin` | v0.40.x | First release with `Meta` at tag 1. v0.40 calls `GrpcClient.GetNetworkMap` on every OS; fixture must continue to produce a full map. +`v0_60_0.bin` | v0.60.x | Same SyncRequest shape as v0.40 but tagged with a newer `NetbirdVersion`. +`current.bin` | latest | Fully-populated `PeerSystemMeta`. +`android_current.bin` | latest, Android | Same shape as `current.bin` with `GoOS=android`; the server must never take the fast path even after the cache is primed. + +## Regenerating + +The generator is forward-compatible: it uses the current proto package with only +the fields each era exposes. Re-run after an intentional proto change: + +``` +go run ./management/server/testdata/sync_request_wire/generate.go +``` + +and review the byte diff. 
An unexpected size change or diff indicates the wire +format has drifted — either adjust the generator (if the drift is intentional +and backwards-compatible) or revert the proto change (if it broke old clients). diff --git a/management/server/testdata/sync_request_wire/generate.go b/management/server/testdata/sync_request_wire/generate.go new file mode 100644 index 000000000..e25d6e16f --- /dev/null +++ b/management/server/testdata/sync_request_wire/generate.go @@ -0,0 +1,102 @@ +//go:build ignore + +// generate.go produces the frozen SyncRequest wire-format fixtures used by +// sync_legacy_wire_test.go. Run with: +// +// go run ./management/server/testdata/sync_request_wire/generate.go +// +// Each fixture is the proto-serialised SyncRequest a client of the indicated +// netbird version would put on the wire. protobuf3 is forward-compatible: an +// old client's fields live at stable tag numbers, so marshalling a current +// SyncRequest that sets only those fields produces bytes byte-for-byte +// compatible with what the old client produced. The fixtures are checked in +// so a future proto change that silently breaks the old wire format is caught +// in CI. +package main + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/golang/protobuf/proto" //nolint:staticcheck // wire-format stability + + mgmtProto "github.com/netbirdio/netbird/shared/management/proto" +) + +func main() { + outDir := filepath.Join("management", "server", "testdata", "sync_request_wire") + if err := os.MkdirAll(outDir, 0o755); err != nil { + fmt.Fprintf(os.Stderr, "mkdir %s: %v\n", outDir, err) + os.Exit(1) + } + + fixtures := map[string]*mgmtProto.SyncRequest{ + // v0.20.0: message SyncRequest {} — no fields on the wire. + "v0_20_0.bin": {}, + + // v0.40.0: Meta added at tag 1. Older meta fields only. 
+ "v0_40_0.bin": { + Meta: &mgmtProto.PeerSystemMeta{ + Hostname: "v40-host", + GoOS: "linux", + OS: "linux", + Platform: "x86_64", + Kernel: "4.15.0", + NetbirdVersion: "0.40.0", + }, + }, + + // v0.60.0: same wire shape as v0.40.0 for SyncRequest. + "v0_60_0.bin": { + Meta: &mgmtProto.PeerSystemMeta{ + Hostname: "v60-host", + GoOS: "linux", + OS: "linux", + Platform: "x86_64", + Kernel: "5.15.0", + NetbirdVersion: "0.60.0", + }, + }, + + // current: fully-populated meta a modern client would send. + "current.bin": { + Meta: &mgmtProto.PeerSystemMeta{ + Hostname: "modern-host", + GoOS: "linux", + OS: "linux", + Platform: "x86_64", + Kernel: "6.5.0", + NetbirdVersion: "0.70.0", + UiVersion: "0.70.0", + KernelVersion: "6.5.0-rc1", + }, + }, + + // android: exercises the never-skip branch regardless of cache state. + "android_current.bin": { + Meta: &mgmtProto.PeerSystemMeta{ + Hostname: "android-host", + GoOS: "android", + OS: "android", + Platform: "arm64", + Kernel: "4.19", + NetbirdVersion: "0.70.0", + }, + }, + } + + for name, msg := range fixtures { + payload, err := proto.Marshal(msg) + if err != nil { + fmt.Fprintf(os.Stderr, "marshal %s: %v\n", name, err) + os.Exit(1) + } + path := filepath.Join(outDir, name) + if err := os.WriteFile(path, payload, 0o644); err != nil { + fmt.Fprintf(os.Stderr, "write %s: %v\n", path, err) + os.Exit(1) + } + fmt.Printf("wrote %s (%d bytes)\n", path, len(payload)) + } +} diff --git a/management/server/testdata/sync_request_wire/v0_20_0.bin b/management/server/testdata/sync_request_wire/v0_20_0.bin new file mode 100644 index 000000000..e69de29bb diff --git a/management/server/testdata/sync_request_wire/v0_40_0.bin b/management/server/testdata/sync_request_wire/v0_40_0.bin new file mode 100644 index 000000000..db9ab85b9 --- /dev/null +++ b/management/server/testdata/sync_request_wire/v0_40_0.bin @@ -0,0 +1,3 @@ + +0 +v40-hostlinux4.15.0*x86_642linux:0.40.0 \ No newline at end of file diff --git 
a/management/server/testdata/sync_request_wire/v0_60_0.bin b/management/server/testdata/sync_request_wire/v0_60_0.bin new file mode 100644 index 000000000..d180d0df7 --- /dev/null +++ b/management/server/testdata/sync_request_wire/v0_60_0.bin @@ -0,0 +1,3 @@ + +0 +v60-hostlinux5.15.0*x86_642linux:0.60.0 \ No newline at end of file diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index d9a1a7d65..a8e8172dc 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -138,7 +138,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { if err != nil { t.Fatal(err) } - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil, nil) if err != nil { t.Fatal(err) }