diff --git a/management/internals/modules/peers/ephemeral/manager/ephemeral.go b/management/internals/modules/peers/ephemeral/manager/ephemeral.go index 758f643d0..0f902ea70 100644 --- a/management/internals/modules/peers/ephemeral/manager/ephemeral.go +++ b/management/internals/modules/peers/ephemeral/manager/ephemeral.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/store" ) @@ -47,6 +48,11 @@ type EphemeralManager struct { lifeTime time.Duration cleanupWindow time.Duration + + // metrics is nil-safe; methods on telemetry.EphemeralPeersMetrics + // no-op when the receiver is nil so deployments without an app + // metrics provider work unchanged. + metrics *telemetry.EphemeralPeersMetrics } // NewEphemeralManager instantiate new EphemeralManager @@ -60,6 +66,15 @@ func NewEphemeralManager(store store.Store, peersManager peers.Manager) *Ephemer } } +// SetMetrics attaches a metrics collector. Safe to call once before +// LoadInitialPeers; later attachment is fine but earlier loads won't be +// reflected in the gauge. Pass nil to detach. +func (e *EphemeralManager) SetMetrics(m *telemetry.EphemeralPeersMetrics) { + e.peersLock.Lock() + e.metrics = m + e.peersLock.Unlock() +} + // LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head // of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new // head. @@ -97,7 +112,9 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee e.peersLock.Lock() defer e.peersLock.Unlock() - e.removePeer(peer.ID) + if e.removePeer(peer.ID) { + e.metrics.DecPending(1) + } // stop the unnecessary timer if e.headPeer == nil && e.timer != nil { @@ -123,6 +140,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. } e.addPeer(peer.AccountID, peer.ID, e.newDeadLine()) + e.metrics.IncPending() if e.timer == nil { delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow if delay < 0 { @@ -145,6 +163,7 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { for _, p := range peers { e.addPeer(p.AccountID, p.ID, t) } + e.metrics.AddPending(int64(len(peers))) log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers)) } @@ -181,6 +200,15 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { e.peersLock.Unlock() + // Drop the gauge by the number of entries we just took off the list, + // regardless of whether the subsequent DeletePeers call succeeds. The + // list invariant is what the gauge tracks; failed delete batches are + // counted separately via CountCleanupError so we can still see them. + if len(deletePeers) > 0 { + e.metrics.CountCleanupRun() + e.metrics.DecPending(int64(len(deletePeers))) + } + peerIDsPerAccount := make(map[string][]string) for id, p := range deletePeers { peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id) @@ -191,7 +219,10 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true) if err != nil { log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err) + e.metrics.CountCleanupError() + continue } + e.metrics.CountPeersCleaned(int64(len(peerIDs))) } } @@ -211,9 +242,12 @@ func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline tim e.tailPeer = ep } -func (e *EphemeralManager) removePeer(id string) { +// removePeer drops the entry from the linked list. Returns true if a +// matching entry was found and removed so callers can keep the pending +// metric gauge in sync. +func (e *EphemeralManager) removePeer(id string) bool { if e.headPeer == nil { - return + return false } if e.headPeer.id == id { @@ -221,7 +255,7 @@ func (e *EphemeralManager) removePeer(id string) { if e.tailPeer.id == id { e.tailPeer = nil } - return + return true } for p := e.headPeer; p.next != nil; p = p.next { @@ -231,9 +265,10 @@ func (e *EphemeralManager) removePeer(id string) { e.tailPeer = p } p.next = p.next.next - return + return true } } + return false } func (e *EphemeralManager) isPeerOnList(id string) bool { diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 89bdf0abe..794c3ebe0 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -112,7 +112,11 @@ func (s *BaseServer) AuthManager() auth.Manager { func (s *BaseServer) EphemeralManager() ephemeral.Manager { return Create(s, func() ephemeral.Manager { - return manager.NewEphemeralManager(s.Store(), s.PeersManager()) + em := manager.NewEphemeralManager(s.Store(), s.PeersManager()) + if metrics := s.Metrics(); metrics != nil { + em.SetMetrics(metrics.EphemeralPeersMetrics()) + } + return em }) } diff --git a/management/server/account.go b/management/server/account.go index e7b4acaac..8e4e595f0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1868,35 +1868,32 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } +// SyncAndMarkPeer is the per-Sync entry point: it refreshes the peer's +// network map and then marks the peer connected with a session token +// derived from syncTime (the moment the gRPC stream opened). Any +// concurrent stream that started earlier loses the optimistic-lock race +// in MarkPeerConnected and bails without writing. func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err) } - err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime) - if err != nil { + if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano()); err != nil { log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } return peer, netMap, postureChecks, dnsfwdPort, nil } +// OnPeerDisconnected is invoked when a sync stream ends. It marks the +// peer disconnected only when the stored SessionStartedAt matches the +// nanosecond token derived from streamStartTime — i.e. only when this +// is the stream that currently owns the peer's session. A mismatch +// means a newer stream has already replaced us, so the disconnect is +// dropped. func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { - peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey) - if err != nil { - log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err) - return nil - } - - if peer.Status.LastSeen.After(streamStartTime) { - log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect", - peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339)) - return nil - } - - err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC()) - if err != nil { + if err := am.MarkPeerDisconnected(ctx, peerPubKey, accountID, streamStartTime.UnixNano()); err != nil { log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } return nil diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 71af0645c..ae3de8d79 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -61,7 +61,8 @@ type Manager interface { GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error + MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error + MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error diff --git a/management/server/account/manager_mock.go b/management/server/account/manager_mock.go index 7ffc41d73..0486e63ec 100644 --- a/management/server/account/manager_mock.go +++ b/management/server/account/manager_mock.go @@ -1305,17 +1305,31 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal } // MarkPeerConnected mocks base method. -func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { +func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, connected, realIP, accountID, syncTime) + ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt) ret0, _ := ret[0].(error) return ret0 } // MarkPeerConnected indicates an expected call of MarkPeerConnected. -func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, connected, realIP, accountID, syncTime interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, connected, realIP, accountID, syncTime) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt) +} + +// MarkPeerDisconnected mocks base method. +func (m *MockManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkPeerDisconnected", ctx, peerKey, accountID, sessionStartedAt) + ret0, _ := ret[0].(error) + return ret0 +} + +// MarkPeerDisconnected indicates an expected call of MarkPeerDisconnected. +func (mr *MockManagerMockRecorder) MarkPeerDisconnected(ctx, peerKey, accountID, sessionStartedAt interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnected", reflect.TypeOf((*MockManager)(nil).MarkPeerDisconnected), ctx, peerKey, accountID, sessionStartedAt) } // OnPeerDisconnected mocks base method. diff --git a/management/server/account_test.go b/management/server/account_test.go index 60720faa6..ba621030c 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1813,7 +1813,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano()) require.NoError(t, err, "unable to mark peer connected") _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ @@ -1884,7 +1884,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. require.NoError(t, err, "unable to get the account") // when we mark peer as connected, the peer login expiration routine should trigger - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano()) require.NoError(t, err, "unable to mark peer connected") failed := waitTimeout(wg, time.Second) @@ -1910,15 +1910,16 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { }, false) require.NoError(t, err, "unable to add peer") - t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) { - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC()) + t.Run("disconnect peer when session token matches", func(t *testing.T) { + streamStartTime := time.Now().UTC() + err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano()) require.NoError(t, err, "unable to mark peer connected") peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err, "unable to get peer") require.True(t, peer.Status.Connected, "peer should be connected") - - streamStartTime := time.Now().UTC() + require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt, + "SessionStartedAt should equal the token we passed in") err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime) require.NoError(t, err) @@ -1926,49 +1927,127 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err) require.False(t, peer.Status.Connected, "peer should be disconnected") + require.Equal(t, int64(0), peer.Status.SessionStartedAt, "SessionStartedAt should be reset to 0") }) - t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) { - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC()) + t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) { + // Newer stream wins on connect (sets SessionStartedAt = now ns). + streamStartTime := time.Now().UTC() + err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano()) require.NoError(t, err, "unable to mark peer connected") peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err) require.True(t, peer.Status.Connected, "peer should be connected") - streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour) + // Older stream tries to mark disconnect with its own (older) session token — + // fencing kicks in and the write is dropped. + staleStreamStartTime := streamStartTime.Add(-1 * time.Hour) - err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime) + err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, staleStreamStartTime) require.NoError(t, err) peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err) require.True(t, peer.Status.Connected, - "peer should remain connected because LastSeen > streamStartTime (zombie stream protection)") + "peer should remain connected because the stored session is newer than the disconnect token") + require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt, + "SessionStartedAt should still hold the winning stream's token") }) - t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) { + t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) { node2SyncTime := time.Now().UTC() - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime) + err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano()) require.NoError(t, err, "node 2 should connect peer") peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err) require.True(t, peer.Status.Connected, "peer should be connected") - require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime") + require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt, + "SessionStartedAt should equal node2SyncTime token") node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute) - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime) + err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano()) require.NoError(t, err, "stale connect should not return error") peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err) require.True(t, peer.Status.Connected, "peer should still be connected") - require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), - "LastSeen should NOT be overwritten by stale syncTime from blocked goroutine") + require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt, + "SessionStartedAt should NOT be overwritten by stale token from blocked goroutine") }) } +// TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace exercises the +// fencing protocol under contention: many goroutines race to mark the +// same peer connected with distinct session tokens at the same time. +// The contract is that the highest token always wins and is what remains +// in the store, regardless of execution order. +func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) { + manager, _, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) + require.NoError(t, err, "unable to get account") + + key, err := wgtypes.GenerateKey() + require.NoError(t, err, "unable to generate WireGuard key") + peerPubKey := key.PublicKey().String() + + _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ + Key: peerPubKey, + Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"}, + }, false) + require.NoError(t, err, "unable to add peer") + + const workers = 16 + base := time.Now().UTC().UnixNano() + tokens := make([]int64, workers) + for i := range tokens { + // Spread tokens by 1ms so the comparison is unambiguous; the + // largest is index workers-1. + tokens[i] = base + int64(i)*int64(time.Millisecond) + } + expected := tokens[workers-1] + + var ready sync.WaitGroup + ready.Add(workers) + var start sync.WaitGroup + start.Add(1) + var done sync.WaitGroup + done.Add(workers) + + // require.* calls t.FailNow which is documented as unsafe from + // non-test goroutines (it calls runtime.Goexit on the wrong stack and + // races with the WaitGroup). Collect errors here and assert from the + // main goroutine after done.Wait(). + errs := make(chan error, workers) + + for i := 0; i < workers; i++ { + token := tokens[i] + go func() { + defer done.Done() + ready.Done() + start.Wait() + errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token) + }() + } + + ready.Wait() + start.Done() + done.Wait() + close(errs) + for err := range errs { + require.NoError(t, err, "MarkPeerConnected must not error under contention") + } + + peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, "peer should be connected after the race") + require.Equal(t, expected, peer.Status.SessionStartedAt, + "the largest token must win regardless of execution order") +} + func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") @@ -1991,7 +2070,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano()) require.NoError(t, err, "unable to mark peer connected") wg := &sync.WaitGroup{} diff --git a/management/server/idp/embedded.go b/management/server/idp/embedded.go index a1852a8bc..821e6ff55 100644 --- a/management/server/idp/embedded.go +++ b/management/server/idp/embedded.go @@ -6,7 +6,9 @@ import ( "errors" "fmt" "net/http" + "net/url" "os" + "path" "strings" "github.com/dexidp/dex/storage" @@ -138,10 +140,13 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { return nil, fmt.Errorf("invalid IdP storage config: %w", err) } - // Build CLI redirect URIs including the device callback (both relative and absolute) + // Build CLI redirect URIs including the device callback. Dex uses the issuer-relative + // path (for example, /oauth2/device/callback) when completing the device flow, so + // include it explicitly in addition to the legacy bare path and absolute URL. cliRedirectURIs := c.CLIRedirectURIs cliRedirectURIs = append(cliRedirectURIs, "/device/callback") - cliRedirectURIs = append(cliRedirectURIs, c.Issuer+"/device/callback") + cliRedirectURIs = append(cliRedirectURIs, issuerRelativeDeviceCallback(c.Issuer)) + cliRedirectURIs = append(cliRedirectURIs, strings.TrimSuffix(c.Issuer, "/")+"/device/callback") // Build dashboard redirect URIs including the OAuth callback for proxy authentication dashboardRedirectURIs := c.DashboardRedirectURIs @@ -154,6 +159,10 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { // MGMT api and the dashboard, adding baseURL means less configuration for the instance admin dashboardPostLogoutRedirectURIs = append(dashboardPostLogoutRedirectURIs, baseURL) + redirectURIs := make([]string, 0) + redirectURIs = append(redirectURIs, cliRedirectURIs...) + redirectURIs = append(redirectURIs, dashboardRedirectURIs...) + cfg := &dex.YAMLConfig{ Issuer: c.Issuer, Storage: dex.Storage{ @@ -179,14 +188,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { ID: staticClientDashboard, Name: "NetBird Dashboard", Public: true, - RedirectURIs: dashboardRedirectURIs, + RedirectURIs: redirectURIs, PostLogoutRedirectURIs: sanitizePostLogoutRedirectURIs(dashboardPostLogoutRedirectURIs), }, { ID: staticClientCLI, Name: "NetBird CLI", Public: true, - RedirectURIs: cliRedirectURIs, + RedirectURIs: redirectURIs, }, }, StaticConnectors: c.StaticConnectors, @@ -217,6 +226,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { return cfg, nil } +func issuerRelativeDeviceCallback(issuer string) string { + u, err := url.Parse(issuer) + if err != nil || u.Path == "" { + return "/device/callback" + } + return path.Join(u.Path, "/device/callback") +} + // Due to how the frontend generates the logout, sometimes it appends a trailing slash // and because Dex only allows exact matches, we need to make sure we always have both // versions of each provided uri @@ -299,7 +316,7 @@ func resolveSessionCookieEncryptionKey(configuredKey string) (string, error) { } } - return "", fmt.Errorf("invalid embedded IdP session cookie encryption key: %s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key))) + return "", fmt.Errorf("invalid embedded IdP session cookie encryption key:%s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key))) } func validSessionCookieEncryptionKeyLength(length int) bool { diff --git a/management/server/idp/embedded_test.go b/management/server/idp/embedded_test.go index 09dc67614..91cd27aee 100644 --- a/management/server/idp/embedded_test.go +++ b/management/server/idp/embedded_test.go @@ -314,6 +314,34 @@ func TestEmbeddedIdPManager_UpdateUserPassword(t *testing.T) { }) } +func TestEmbeddedIdPConfig_ToYAMLConfig_IncludesDeviceCallbackRedirectURI(t *testing.T) { + config := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "https://example.com/oauth2", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: filepath.Join(t.TempDir(), "dex.db"), + }, + }, + } + + yamlConfig, err := config.ToYAMLConfig() + require.NoError(t, err) + + var cliRedirectURIs []string + for _, client := range yamlConfig.StaticClients { + if client.ID == staticClientCLI { + cliRedirectURIs = client.RedirectURIs + break + } + } + require.NotEmpty(t, cliRedirectURIs) + assert.Contains(t, cliRedirectURIs, "/device/callback") + assert.Contains(t, cliRedirectURIs, "/oauth2/device/callback") + assert.Contains(t, cliRedirectURIs, "https://example.com/oauth2/device/callback") +} + func TestEmbeddedIdPConfig_ToYAMLConfig_SessionCookieEncryptionKey(t *testing.T) { t.Setenv(sessionCookieEncryptionKeyEnv, "") diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 08091d4b7..aba408184 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -38,7 +38,8 @@ type MockAccountManager struct { GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error + MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error + MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) @@ -227,7 +228,14 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { +func (am *MockAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { + // Mirror DefaultAccountManager.OnPeerDisconnected: drive the fencing + // hook so tests that inject MarkPeerDisconnectedFunc actually observe + // disconnect events. Falls through to nil when no hook is set, which + // is the original behaviour. + if am.MarkPeerDisconnectedFunc != nil { + return am.MarkPeerDisconnectedFunc(ctx, peerPubKey, accountID, streamStartTime.UnixNano()) + } return nil } @@ -328,13 +336,21 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { +func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error { if am.MarkPeerConnectedFunc != nil { - return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime) + return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt) } return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } +// MarkPeerDisconnected mock implementation of MarkPeerDisconnected from server.AccountManager interface +func (am *MockAccountManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error { + if am.MarkPeerDisconnectedFunc != nil { + return am.MarkPeerDisconnectedFunc(ctx, peerKey, accountID, sessionStartedAt) + } + return status.Errorf(codes.Unimplemented, "method MarkPeerDisconnected is not implemented") +} + // DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { if am.DeleteAccountFunc != nil { diff --git a/management/server/peer.go b/management/server/peer.go index c3b130ba2..34b681f51 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -16,7 +16,6 @@ import ( "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/idp" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/permissions/modules" @@ -29,6 +28,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/shared/management/status" ) @@ -63,56 +63,64 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) } -// MarkPeerConnected marks peer as connected (true) or disconnected (false) -// syncTime is used as the LastSeen timestamp and for stale request detection -func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { - var peer *nbpeer.Peer - var settings *types.Settings - var expired bool - var err error - var skipped bool +// MarkPeerConnected marks a peer as connected with optimistic-locked +// fencing on PeerStatus.SessionStartedAt. The sessionStartedAt argument +// is the start time of the gRPC sync stream that owns this update, +// expressed as Unix nanoseconds — only the call whose token is greater +// than what's stored wins. LastSeen is written by the database itself; +// we never pass it down. +// +// Disconnects use MarkPeerDisconnected and require the session to match +// exactly; see PeerStatus.SessionStartedAt for the protocol. +func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64) error { + start := time.Now() + defer func() { + am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start)) + }() - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey) - if err != nil { - return err + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey) + if err != nil { + outcome := telemetry.PeerStatusError + if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { + outcome = telemetry.PeerStatusPeerNotFound } - - if connected && !syncTime.After(peer.Status.LastSeen) { - log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect", - peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339)) - skipped = true - return nil - } - - expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime) + am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, outcome) return err - }) - if skipped { + } + + updated, err := am.Store.MarkPeerConnectedIfNewerSession(ctx, accountID, peer.ID, sessionStartedAt) + if err != nil { + am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusError) + return err + } + if !updated { + am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusStale) + log.WithContext(ctx).Tracef("peer %s already has a newer session in store, skipping connect", peer.ID) return nil } - if err != nil { - return err + am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusApplied) + + if am.geo != nil && realIP != nil { + am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP) } + expired := peer.Status != nil && peer.Status.LoginExpired + if peer.AddedWithSSOLogin() { - settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return err } - if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { am.schedulePeerLoginExpiration(ctx, accountID) } - if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } } if expired { - err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) - if err != nil { + if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil { return fmt.Errorf("notify network map controller of peer update: %w", err) } } @@ -120,41 +128,60 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return nil } -func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) { - oldStatus := peer.Status.Copy() - newStatus := oldStatus - newStatus.LastSeen = syncTime - newStatus.Connected = connected - // whenever peer got connected that means that it logged in successfully - if newStatus.Connected { - newStatus.LoginExpired = false - } - peer.Status = newStatus +// MarkPeerDisconnected marks a peer as disconnected, but only when the +// stored session token matches the one passed in. A mismatch means a +// newer stream has already taken ownership of the peer — disconnects from +// the older stream are ignored. LastSeen is written by the database. +func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64) error { + start := time.Now() + defer func() { + am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusDisconnect, time.Since(start)) + }() - if geo != nil && realIP != nil { - location, err := geo.Lookup(realIP) - if err != nil { - log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err) - } else { - peer.Location.ConnectionIP = realIP - peer.Location.CountryCode = location.Country.ISOCode - peer.Location.CityName = location.City.Names.En - peer.Location.GeoNameID = location.City.GeonameID - err = transaction.SavePeerLocation(ctx, accountID, peer) - if err != nil { - log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) - } - } - } - - log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected) - - err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey) if err != nil { - return false, err + outcome := telemetry.PeerStatusError + if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { + outcome = telemetry.PeerStatusPeerNotFound + } + am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, outcome) + return err } - return oldStatus.LoginExpired, nil + updated, err := am.Store.MarkPeerDisconnectedIfSameSession(ctx, accountID, peer.ID, sessionStartedAt) + if err != nil { + am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusError) + return err + } + if !updated { + am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusStale) + log.WithContext(ctx).Tracef("peer %s session token mismatch on disconnect (token=%d), skipping", + peer.ID, sessionStartedAt) + return nil + } + am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied) + return nil +} + +// updatePeerLocationIfChanged refreshes the geolocation on a separate +// row update, only when the connection IP actually changed. Geo lookups +// are expensive so we skip same-IP reconnects. +func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) { + if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) { + return + } + location, err := am.geo.Lookup(realIP) + if err != nil { + log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err) + return + } + peer.Location.ConnectionIP = realIP + peer.Location.CountryCode = location.Country.ISOCode + peer.Location.CityName = location.City.Names.En + peer.Location.GeoNameID = location.City.GeonameID + if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil { + log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) + } } // UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated. diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 17df761a1..2963dfcbd 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -74,8 +74,19 @@ type ProxyMeta struct { } type PeerStatus struct { //nolint:revive - // LastSeen is the last time peer was connected to the management service + // LastSeen is the last time the peer status was updated (i.e. the last + // time we observed the peer being alive on a sync stream). Written by + // the database (CURRENT_TIMESTAMP) — callers do not supply it. LastSeen time.Time + // SessionStartedAt records when the currently-active sync stream began, + // stored as Unix nanoseconds. It acts as the optimistic-locking token + // for status updates: a stream is only allowed to mutate the peer's + // status when its own token strictly exceeds the stored token (when connecting) + // or matches it exactly (for disconnects). Zero means "no + // active session". Integer nanoseconds are used so equality is + // precision-safe across drivers, and so the predicates compose to a + // single bigint comparison. + SessionStartedAt int64 // Connected indicates whether peer is connected to the management service or not Connected bool // LoginExpired @@ -375,10 +386,14 @@ func (p *Peer) EventMeta(dnsDomain string) map[string]any { return meta } -// Copy PeerStatus +// Copy PeerStatus. SessionStartedAt must be propagated so clone-based +// callers (Peer.Copy, MarkLoginExpired, UpdateLastLogin) don't silently +// reset the fencing token to zero — that would let any subsequent +// SavePeerStatus write reopen the optimistic-lock window. func (p *PeerStatus) Copy() *PeerStatus { return &PeerStatus{ LastSeen: p.LastSeen, + SessionStartedAt: p.SessionStartedAt, Connected: p.Connected, LoginExpired: p.LoginExpired, RequiresApproval: p.RequiresApproval, diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 893ee2168..8cf37de56 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -498,8 +498,9 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string, peerCopy.Status = &peerStatus fieldsToUpdate := []string{ - "peer_status_last_seen", "peer_status_connected", - "peer_status_login_expired", "peer_status_required_approval", + "peer_status_last_seen", "peer_status_session_started_at", + "peer_status_connected", "peer_status_login_expired", + "peer_status_requires_approval", } result := s.db.Model(&nbpeer.Peer{}). Select(fieldsToUpdate). @@ -516,6 +517,69 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string, return nil } +// MarkPeerConnectedIfNewerSession is an atomic optimistic-locked update. +// The peer is marked connected with the given session token only when +// the stored SessionStartedAt is strictly smaller than the incoming +// one — equivalently, when no newer stream has already taken ownership. +// The sentinel zero (set on peer creation or after a disconnect) counts +// as the smallest possible token. This is the write half of the +// fencing protocol described on PeerStatus.SessionStartedAt. +// +// The post-write side effects in the caller — geo lookup, +// schedulePeerLoginExpiration, checkAndSchedulePeerInactivityExpiration, +// OnPeersUpdated — all run AFTER this method returns and are deliberately +// outside the database write so they cannot extend the row-lock window. +// +// LastSeen is set to the database's clock (CURRENT_TIMESTAMP) at the +// moment the row is written. The caller never supplies LastSeen because +// the value would otherwise drift under lock contention — a Go-side +// time.Now() taken before the write can land minutes later than the +// actual UPDATE under load, which previously caused real ordering bugs. +func (s *SqlStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) { + result := s.db.WithContext(ctx). + Model(&nbpeer.Peer{}). + Where(accountAndIDQueryCondition, accountID, peerID). + Where("peer_status_session_started_at < ?", newSessionStartedAt). + Updates(map[string]any{ + "peer_status_connected": true, + "peer_status_last_seen": gorm.Expr("CURRENT_TIMESTAMP"), + "peer_status_session_started_at": newSessionStartedAt, + "peer_status_login_expired": false, + }) + if result.Error != nil { + return false, status.Errorf(status.Internal, "mark peer connected: %v", result.Error) + } + return result.RowsAffected > 0, nil +} + +// MarkPeerDisconnectedIfSameSession is an atomic optimistic-locked update. +// The peer is marked disconnected only when the stored SessionStartedAt +// matches the incoming token — meaning the stream that owns the current +// session is the one ending. If a newer stream has already replaced the +// session, the update is skipped. LastSeen is set to CURRENT_TIMESTAMP at +// write time; see MarkPeerConnectedIfNewerSession for the rationale. +// +// A zero sessionStartedAt is rejected at the call site; the underlying +// WHERE on equality would otherwise match every never-connected peer. +func (s *SqlStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) { + if sessionStartedAt == 0 { + return false, nil + } + result := s.db.WithContext(ctx). + Model(&nbpeer.Peer{}). + Where(accountAndIDQueryCondition, accountID, peerID). + Where("peer_status_session_started_at = ?", sessionStartedAt). + Updates(map[string]any{ + "peer_status_connected": false, + "peer_status_last_seen": gorm.Expr("CURRENT_TIMESTAMP"), + "peer_status_session_started_at": int64(0), + }) + if result.Error != nil { + return false, status.Errorf(status.Internal, "mark peer disconnected: %v", result.Error) + } + return result.RowsAffected > 0, nil +} + func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error { // To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields. var peerCopy nbpeer.Peer @@ -1723,9 +1787,10 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, - meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_connected, peer_status_login_expired, - peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, - location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6 FROM peers WHERE account_id = $1` + meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_session_started_at, + peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, + location_country_code, location_city_name, location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6 + FROM peers WHERE account_id = $1` rows, err := s.pool.Query(ctx, query, accountID) if err != nil { return nil, err @@ -1738,6 +1803,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee lastLogin, createdAt sql.NullTime sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool peerStatusLastSeen sql.NullTime + peerStatusSessionStartedAt sql.NullInt64 peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval, proxyEmbedded sql.NullBool ip, extraDNS, netAddr, env, flags, files, capabilities, connIP, ipv6 []byte metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString @@ -1752,8 +1818,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee &allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform, &metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr, &metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, &capabilities, - &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, - &locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded, &proxyCluster, &ipv6) + &peerStatusLastSeen, &peerStatusSessionStartedAt, &peerStatusConnected, &peerStatusLoginExpired, + &peerStatusRequiresApproval, &connIP, &locationCountryCode, &locationCityName, &locationGeoNameID, + &proxyEmbedded, &proxyCluster, &ipv6) if err == nil { if lastLogin.Valid { @@ -1780,6 +1847,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee if peerStatusLastSeen.Valid { p.Status.LastSeen = peerStatusLastSeen.Time } + if peerStatusSessionStartedAt.Valid { + p.Status.SessionStartedAt = peerStatusSessionStartedAt.Int64 + } if peerStatusConnected.Valid { p.Status.Connected = peerStatusConnected.Bool } diff --git a/management/server/store/store.go b/management/server/store/store.go index aa601c33f..a723c1fc3 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -167,6 +167,21 @@ type Store interface { GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error + // MarkPeerConnectedIfNewerSession sets the peer to connected with the + // given session token, but only when the stored SessionStartedAt is + // strictly less than newSessionStartedAt (the sentinel zero counts as + // "older"). LastSeen is recorded by the database at the moment the + // row is updated — never by the caller — so it always reflects the + // real write time even under lock contention. + // Returns true when the update happened, false when this stream lost + // the race against a newer session. + MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) + // MarkPeerDisconnectedIfSameSession sets the peer to disconnected and + // resets SessionStartedAt to zero, but only when the stored + // SessionStartedAt equals the given sessionStartedAt. LastSeen is + // recorded by the database. Returns true when the update happened, + // false when a newer session has taken over. + MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error ApproveAccountPeers(ctx context.Context, accountID string) (int, error) DeletePeer(ctx context.Context, accountID string, peerID string) error diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 9780c521e..d51629606 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -2878,6 +2878,36 @@ func (mr *MockStoreMockRecorder) SavePeerStatus(ctx, accountID, peerID, status i return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerStatus", reflect.TypeOf((*MockStore)(nil).SavePeerStatus), ctx, accountID, peerID, status) } +// MarkPeerConnectedIfNewerSession mocks base method. +func (m *MockStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkPeerConnectedIfNewerSession", ctx, accountID, peerID, newSessionStartedAt) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MarkPeerConnectedIfNewerSession indicates an expected call of MarkPeerConnectedIfNewerSession. +func (mr *MockStoreMockRecorder) MarkPeerConnectedIfNewerSession(ctx, accountID, peerID, newSessionStartedAt interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnectedIfNewerSession", reflect.TypeOf((*MockStore)(nil).MarkPeerConnectedIfNewerSession), ctx, accountID, peerID, newSessionStartedAt) +} + +// MarkPeerDisconnectedIfSameSession mocks base method. +func (m *MockStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkPeerDisconnectedIfSameSession", ctx, accountID, peerID, sessionStartedAt) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MarkPeerDisconnectedIfSameSession indicates an expected call of MarkPeerDisconnectedIfSameSession. +func (mr *MockStoreMockRecorder) MarkPeerDisconnectedIfSameSession(ctx, accountID, peerID, sessionStartedAt interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnectedIfSameSession", reflect.TypeOf((*MockStore)(nil).MarkPeerDisconnectedIfSameSession), ctx, accountID, peerID, sessionStartedAt) +} + // SavePolicy mocks base method. func (m *MockStore) SavePolicy(ctx context.Context, policy *types2.Policy) error { m.ctrl.T.Helper() diff --git a/management/server/telemetry/accountmanager_metrics.go b/management/server/telemetry/accountmanager_metrics.go index 518aae7eb..bb6fb7e12 100644 --- a/management/server/telemetry/accountmanager_metrics.go +++ b/management/server/telemetry/accountmanager_metrics.go @@ -16,6 +16,8 @@ type AccountManagerMetrics struct { getPeerNetworkMapDurationMs metric.Float64Histogram networkMapObjectCount metric.Int64Histogram peerMetaUpdateCount metric.Int64Counter + peerStatusUpdateCounter metric.Int64Counter + peerStatusUpdateDurationMs metric.Float64Histogram } // NewAccountManagerMetrics creates an instance of AccountManagerMetrics @@ -64,6 +66,24 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account return nil, err } + // peerStatusUpdateCounter records every attempt to mark a peer as connected or disconnected + peerStatusUpdateCounter, err := meter.Int64Counter("management.account.peer.status.update.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of peer status update attempts, labeled by operation (connect|disconnect) and outcome (applied|stale|error|peer_not_found)")) + if err != nil { + return nil, err + } + + peerStatusUpdateDurationMs, err := meter.Float64Histogram("management.account.peer.status.update.duration.ms", + metric.WithUnit("milliseconds"), + metric.WithExplicitBucketBoundaries( + 1, 5, 15, 25, 50, 100, 250, 500, 1000, 2000, 5000, + ), + metric.WithDescription("Duration of a peer status update (fence UPDATE + post-write side effects), labeled by operation")) + if err != nil { + return nil, err + } + return &AccountManagerMetrics{ ctx: ctx, getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs, @@ -71,10 +91,35 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account updateAccountPeersCounter: updateAccountPeersCounter, networkMapObjectCount: networkMapObjectCount, peerMetaUpdateCount: peerMetaUpdateCount, + peerStatusUpdateCounter: peerStatusUpdateCounter, + peerStatusUpdateDurationMs: peerStatusUpdateDurationMs, }, nil } +// PeerStatusOperation labels the kind of fence-locked peer status write. +type PeerStatusOperation string + +// PeerStatusOutcome labels how a fence-locked peer status write resolved. +type PeerStatusOutcome string + +const ( + PeerStatusConnect PeerStatusOperation = "connect" + PeerStatusDisconnect PeerStatusOperation = "disconnect" + + // PeerStatusApplied — the fence WHERE matched and the UPDATE landed. + PeerStatusApplied PeerStatusOutcome = "applied" + // PeerStatusStale — the fence WHERE rejected the write because a + // newer session has already taken ownership (connect: stored token + // >= incoming; disconnect: stored token != incoming). + PeerStatusStale PeerStatusOutcome = "stale" + // PeerStatusError — the store returned a non-NotFound error. + PeerStatusError PeerStatusOutcome = "error" + // PeerStatusPeerNotFound — the peer lookup failed (the peer was + // deleted between the gRPC sync handshake and the status write). + PeerStatusPeerNotFound PeerStatusOutcome = "peer_not_found" +) + // CountUpdateAccountPeersDuration counts the duration of updating account peers func (metrics *AccountManagerMetrics) CountUpdateAccountPeersDuration(duration time.Duration) { metrics.updateAccountPeersDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6) @@ -104,3 +149,23 @@ func (metrics *AccountManagerMetrics) CountUpdateAccountPeersTriggered(resource, func (metrics *AccountManagerMetrics) CountPeerMetUpdate() { metrics.peerMetaUpdateCount.Add(metrics.ctx, 1) } + +// CountPeerStatusUpdate increments the connect/disconnect counter, +// labeled by operation and outcome. Both labels are bounded enums. +func (metrics *AccountManagerMetrics) CountPeerStatusUpdate(op PeerStatusOperation, outcome PeerStatusOutcome) { + metrics.peerStatusUpdateCounter.Add(metrics.ctx, 1, + metric.WithAttributes( + attribute.String("operation", string(op)), + attribute.String("outcome", string(outcome)), + ), + ) +} + +// RecordPeerStatusUpdateDuration records the wall-clock time spent +// running a peer status update (including post-write side effects), +// labeled by operation. +func (metrics *AccountManagerMetrics) RecordPeerStatusUpdateDuration(op PeerStatusOperation, d time.Duration) { + metrics.peerStatusUpdateDurationMs.Record(metrics.ctx, float64(d.Nanoseconds())/1e6, + metric.WithAttributes(attribute.String("operation", string(op))), + ) +} diff --git a/management/server/telemetry/app_metrics.go b/management/server/telemetry/app_metrics.go index 1fd78bc3a..fd9087a96 100644 --- a/management/server/telemetry/app_metrics.go +++ b/management/server/telemetry/app_metrics.go @@ -29,6 +29,7 @@ type MockAppMetrics struct { StoreMetricsFunc func() *StoreMetrics UpdateChannelMetricsFunc func() *UpdateChannelMetrics AddAccountManagerMetricsFunc func() *AccountManagerMetrics + EphemeralPeersMetricsFunc func() *EphemeralPeersMetrics } // GetMeter mocks the GetMeter function of the AppMetrics interface @@ -103,6 +104,14 @@ func (mock *MockAppMetrics) AccountManagerMetrics() *AccountManagerMetrics { return nil } +// EphemeralPeersMetrics mocks the MockAppMetrics function of the EphemeralPeersMetrics interface +func (mock *MockAppMetrics) EphemeralPeersMetrics() *EphemeralPeersMetrics { + if mock.EphemeralPeersMetricsFunc != nil { + return mock.EphemeralPeersMetricsFunc() + } + return nil +} + // AppMetrics is metrics interface type AppMetrics interface { GetMeter() metric2.Meter @@ -114,6 +123,7 @@ type AppMetrics interface { StoreMetrics() *StoreMetrics UpdateChannelMetrics() *UpdateChannelMetrics AccountManagerMetrics() *AccountManagerMetrics + EphemeralPeersMetrics() *EphemeralPeersMetrics } // defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/ @@ -129,6 +139,7 @@ type defaultAppMetrics struct { storeMetrics *StoreMetrics updateChannelMetrics *UpdateChannelMetrics accountManagerMetrics *AccountManagerMetrics + ephemeralMetrics *EphemeralPeersMetrics } // IDPMetrics returns metrics for the idp package @@ -161,6 +172,11 @@ func (appMetrics *defaultAppMetrics) AccountManagerMetrics() *AccountManagerMetr return appMetrics.accountManagerMetrics } +// EphemeralPeersMetrics returns metrics for the ephemeral peer cleanup loop +func (appMetrics *defaultAppMetrics) EphemeralPeersMetrics() *EphemeralPeersMetrics { + return appMetrics.ephemeralMetrics +} + // Close stop application metrics HTTP handler and closes listener. func (appMetrics *defaultAppMetrics) Close() error { if appMetrics.listener == nil { @@ -245,6 +261,11 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) { return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err) } + ephemeralMetrics, err := NewEphemeralPeersMetrics(ctx, meter) + if err != nil { + return nil, fmt.Errorf("failed to initialize ephemeral peers metrics: %w", err) + } + return &defaultAppMetrics{ Meter: meter, ctx: ctx, @@ -254,6 +275,7 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) { storeMetrics: storeMetrics, updateChannelMetrics: updateChannelMetrics, accountManagerMetrics: accountManagerMetrics, + ephemeralMetrics: ephemeralMetrics, }, nil } @@ -290,6 +312,11 @@ func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetric return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err) } + ephemeralMetrics, err := NewEphemeralPeersMetrics(ctx, meter) + if err != nil { + return nil, fmt.Errorf("failed to initialize ephemeral peers metrics: %w", err) + } + return &defaultAppMetrics{ Meter: meter, ctx: ctx, @@ -300,5 +327,6 @@ func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetric storeMetrics: storeMetrics, updateChannelMetrics: updateChannelMetrics, accountManagerMetrics: accountManagerMetrics, + ephemeralMetrics: ephemeralMetrics, }, nil } diff --git a/management/server/telemetry/ephemeral_metrics.go b/management/server/telemetry/ephemeral_metrics.go new file mode 100644 index 000000000..a7fb432f8 --- /dev/null +++ b/management/server/telemetry/ephemeral_metrics.go @@ -0,0 +1,115 @@ +package telemetry + +import ( + "context" + + "go.opentelemetry.io/otel/metric" +) + +// EphemeralPeersMetrics tracks the ephemeral peer cleanup pipeline: how +// many peers are currently scheduled for deletion, how many tick runs +// the cleaner has performed, how many peers it has removed, and how +// many delete batches failed. +type EphemeralPeersMetrics struct { + ctx context.Context + + pending metric.Int64UpDownCounter + cleanupRuns metric.Int64Counter + peersCleaned metric.Int64Counter + errors metric.Int64Counter +} + +// NewEphemeralPeersMetrics constructs the ephemeral cleanup counters. +func NewEphemeralPeersMetrics(ctx context.Context, meter metric.Meter) (*EphemeralPeersMetrics, error) { + pending, err := meter.Int64UpDownCounter("management.ephemeral.peers.pending", + metric.WithUnit("1"), + metric.WithDescription("Number of ephemeral peers currently waiting to be cleaned up")) + if err != nil { + return nil, err + } + + cleanupRuns, err := meter.Int64Counter("management.ephemeral.cleanup.runs.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of ephemeral cleanup ticks that processed at least one peer")) + if err != nil { + return nil, err + } + + peersCleaned, err := meter.Int64Counter("management.ephemeral.peers.cleaned.counter", + metric.WithUnit("1"), + metric.WithDescription("Total number of ephemeral peers deleted by the cleanup loop")) + if err != nil { + return nil, err + } + + errors, err := meter.Int64Counter("management.ephemeral.cleanup.errors.counter", + metric.WithUnit("1"), + metric.WithDescription("Number of ephemeral cleanup batches (per account) that failed to delete")) + if err != nil { + return nil, err + } + + return &EphemeralPeersMetrics{ + ctx: ctx, + pending: pending, + cleanupRuns: cleanupRuns, + peersCleaned: peersCleaned, + errors: errors, + }, nil +} + +// All methods are nil-receiver safe so callers that haven't wired metrics +// (tests, self-hosted with metrics off) can invoke them unconditionally. + +// IncPending bumps the pending gauge when a peer is added to the cleanup list. +func (m *EphemeralPeersMetrics) IncPending() { + if m == nil { + return + } + m.pending.Add(m.ctx, 1) +} + +// AddPending bumps the pending gauge by n — used at startup when the +// initial set of ephemeral peers is loaded from the store. +func (m *EphemeralPeersMetrics) AddPending(n int64) { + if m == nil || n <= 0 { + return + } + m.pending.Add(m.ctx, n) +} + +// DecPending decreases the pending gauge — used both when a peer reconnects +// before its deadline (removed from the list) and when a cleanup tick +// actually deletes it. +func (m *EphemeralPeersMetrics) DecPending(n int64) { + if m == nil || n <= 0 { + return + } + m.pending.Add(m.ctx, -n) +} + +// CountCleanupRun records one cleanup pass that processed >0 peers. Idle +// ticks (nothing to do) deliberately don't increment so the rate +// reflects useful work. +func (m *EphemeralPeersMetrics) CountCleanupRun() { + if m == nil { + return + } + m.cleanupRuns.Add(m.ctx, 1) +} + +// CountPeersCleaned records the number of peers a single tick deleted. +func (m *EphemeralPeersMetrics) CountPeersCleaned(n int64) { + if m == nil || n <= 0 { + return + } + m.peersCleaned.Add(m.ctx, n) +} + +// CountCleanupError records a failed delete batch. +func (m *EphemeralPeersMetrics) CountCleanupError() { + if m == nil { + return + } + m.errors.Add(m.ctx, 1) +}