diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go index aa7cd8630..53c52b3aa 100644 --- a/management/internals/modules/reverseproxy/proxy/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -11,9 +11,9 @@ import ( // Manager defines the interface for proxy operations type Manager interface { - Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error - Disconnect(ctx context.Context, proxyID string) error - Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) + Disconnect(ctx context.Context, proxyID, sessionID string) error + Heartbeat(ctx context.Context, p *Proxy) error GetActiveClusterAddresses(ctx context.Context) ([]string, error) GetActiveClusters(ctx context.Context) ([]Cluster, error) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager.go b/management/internals/modules/reverseproxy/proxy/manager/manager.go index d13334e83..341e8c943 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager.go @@ -13,7 +13,8 @@ import ( // store defines the interface for proxy persistence operations type store interface { SaveProxy(ctx context.Context, p *proxy.Proxy) error - UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + DisconnectProxy(ctx context.Context, proxyID, sessionID string) error + UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool @@ -43,7 +44,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) { // Connect registers a new proxy connection in the database. // capabilities may be nil for old proxies that do not report them. -func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error { +func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) { now := time.Now() var caps proxy.Capabilities if capabilities != nil { @@ -51,6 +52,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress } p := &proxy.Proxy{ ID: proxyID, + SessionID: sessionID, ClusterAddress: clusterAddress, IPAddress: ipAddress, LastSeen: now, @@ -61,48 +63,42 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress if err := m.store.SaveProxy(ctx, p); err != nil { log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err) - return err + return nil, err } log.WithContext(ctx).WithFields(log.Fields{ "proxyID": proxyID, + "sessionID": sessionID, "clusterAddress": clusterAddress, "ipAddress": ipAddress, }).Info("proxy connected") - return nil + return p, nil } -// Disconnect marks a proxy as disconnected in the database -func (m Manager) Disconnect(ctx context.Context, proxyID string) error { - now := time.Now() - p := &proxy.Proxy{ - ID: proxyID, - Status: "disconnected", - DisconnectedAt: &now, - LastSeen: now, - } - - if err := m.store.SaveProxy(ctx, p); err != nil { - log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err) +// Disconnect marks a proxy as disconnected in the database. +func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error { + if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil { + log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err) return err } log.WithContext(ctx).WithFields(log.Fields{ - "proxyID": proxyID, + "proxyID": proxyID, + "sessionID": sessionID, }).Info("proxy disconnected") return nil } -// Heartbeat updates the proxy's last seen timestamp -func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { - if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil { - log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err) +// Heartbeat updates the proxy's last seen timestamp. +func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error { + if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil { + log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err) return err } - log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID) + log.WithContext(ctx).Tracef("updated heartbeat for proxy %s session %s", p.ID, p.SessionID) m.metrics.IncrementProxyHeartbeatCount() return nil } diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go index 282ca0ba5..98d97b3c6 100644 --- a/management/internals/modules/reverseproxy/proxy/manager_mock.go +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -93,31 +93,32 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte } // Connect mocks base method. -func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error { +func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities) + ret0, _ := ret[0].(*Proxy) + ret1, _ := ret[1].(error) + return ret0, ret1 } // Connect indicates an expected call of Connect. -func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities) } // Disconnect mocks base method. -func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error { +func (m *MockManager) Disconnect(ctx context.Context, proxyID, sessionID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID) + ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID, sessionID) ret0, _ := ret[0].(error) return ret0 } // Disconnect indicates an expected call of Disconnect. -func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID, sessionID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID, sessionID) } // GetActiveClusterAddresses mocks base method. @@ -151,17 +152,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Ca } // Heartbeat mocks base method. -func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +func (m *MockManager) Heartbeat(ctx context.Context, p *Proxy) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress) + ret := m.ctrl.Call(m, "Heartbeat", ctx, p) ret0, _ := ret[0].(error) return ret0 } // Heartbeat indicates an expected call of Heartbeat. -func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p) } // MockController is a mock of Controller interface. diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go index 339c82446..dcedb8811 100644 --- a/management/internals/modules/reverseproxy/proxy/proxy.go +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -18,12 +18,13 @@ type Capabilities struct { // Proxy represents a reverse proxy instance type Proxy struct { ID string `gorm:"primaryKey;type:varchar(255)"` + SessionID string `gorm:"type:varchar(36)"` ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"` IPAddress string `gorm:"type:varchar(45)"` LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"` ConnectedAt *time.Time DisconnectedAt *time.Time - Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"` + Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"` Capabilities Capabilities `gorm:"embedded"` CreatedAt time.Time UpdatedAt time.Time diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index a5e352e75..d811a0f69 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -16,6 +16,7 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/google/uuid" log "github.com/sirupsen/logrus" "golang.org/x/oauth2" "google.golang.org/grpc/codes" @@ -89,6 +90,7 @@ const pkceVerifierTTL = 10 * time.Minute // proxyConnection represents a connected proxy type proxyConnection struct { proxyID string + sessionID string address string capabilities *proto.ProxyCapabilities stream proto.ProxyService_GetMappingUpdateServer @@ -166,9 +168,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest return status.Errorf(codes.InvalidArgument, "proxy address is invalid") } + sessionID := uuid.NewString() + + if old, loaded := s.connectedProxies.Load(proxyID); loaded { + oldConn := old.(*proxyConnection) + log.WithFields(log.Fields{ + "proxy_id": proxyID, + "old_session_id": oldConn.sessionID, + "new_session_id": sessionID, + }).Info("Superseding existing proxy connection") + oldConn.cancel() + } + connCtx, cancel := context.WithCancel(ctx) conn := &proxyConnection{ proxyID: proxyID, + sessionID: sessionID, address: proxyAddress, capabilities: req.GetCapabilities(), stream: stream, @@ -188,12 +203,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest caps = &proxy.Capabilities{ SupportsCustomPorts: c.SupportsCustomPorts, RequireSubdomain: c.RequireSubdomain, - SupportsCrowdsec: c.SupportsCrowdsec, + SupportsCrowdsec: c.SupportsCrowdsec, } } - if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil { + proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps) + if err != nil { log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) - s.connectedProxies.Delete(proxyID) + s.connectedProxies.CompareAndDelete(proxyID, conn) if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil { log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr) } @@ -202,22 +218,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest log.WithFields(log.Fields{ "proxy_id": proxyID, + "session_id": sessionID, "address": proxyAddress, "cluster_addr": proxyAddress, "total_proxies": len(s.GetConnectedProxies()), }).Info("Proxy registered in cluster") defer func() { - if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil { - log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err) + if !s.connectedProxies.CompareAndDelete(proxyID, conn) { + log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID) + cancel() + return } - s.connectedProxies.Delete(proxyID) if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil { log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err) } + if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil { + log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err) + } cancel() - log.Infof("Proxy %s disconnected", proxyID) + log.Infof("Proxy %s session %s disconnected", proxyID, sessionID) }() if err := s.sendSnapshot(ctx, conn); err != nil { @@ -227,29 +248,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest errChan := make(chan error, 2) go s.sender(conn, errChan) - // Start heartbeat goroutine - go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo) + go s.heartbeat(connCtx, proxyRecord) select { case err := <-errChan: + log.WithContext(ctx).Warnf("Failed to send update: %v", err) return fmt.Errorf("send update to proxy %s: %w", proxyID, err) case <-connCtx.Done(): + log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID) return connCtx.Err() } } // heartbeat updates the proxy's last_seen timestamp every minute -func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) { +func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() for { select { case <-ticker.C: - if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil { - log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err) + if err := s.proxyManager.Heartbeat(ctx, p); err != nil { + log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err) } case <-ctx.Done(): + log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID) return } } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 0a716d08d..1fa3d08ee 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -5437,13 +5437,35 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { return nil } -// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist -func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +// DisconnectProxy marks a proxy as disconnected only if the session ID matches. +// This prevents a slow-to-close old session from overwriting a newer reconnection. +func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error { + now := time.Now() + result := s.db. + Model(&proxy.Proxy{}). + Where("id = ? AND session_id = ?", proxyID, sessionID). + Updates(map[string]any{ + "status": "disconnected", + "disconnected_at": now, + "last_seen": now, + }) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, result.Error) + return status.Errorf(status.Internal, "failed to disconnect proxy") + } + if result.RowsAffected == 0 { + log.WithContext(ctx).Debugf("proxy %s session %s: no row updated (superseded by newer session)", proxyID, sessionID) + } + return nil +} + +// UpdateProxyHeartbeat updates the last_seen timestamp for the proxy's current session. +func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error { now := time.Now() result := s.db. Model(&proxy.Proxy{}). - Where("id = ? AND status = ?", proxyID, "connected"). + Where("id = ? AND session_id = ?", p.ID, p.SessionID). Update("last_seen", now) if result.Error != nil { @@ -5452,17 +5474,11 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd } if result.RowsAffected == 0 { - p := &proxy.Proxy{ - ID: proxyID, - ClusterAddress: clusterAddress, - IPAddress: ipAddress, - LastSeen: now, - ConnectedAt: &now, - Status: "connected", - } - if err := s.db.Save(p).Error; err != nil { - log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err) - return status.Errorf(status.Internal, "failed to create proxy on heartbeat") + p.LastSeen = now + p.ConnectedAt = &now + p.Status = "connected" + if err := s.db.Create(p).Error; err != nil { + log.WithContext(ctx).Debugf("proxy %s session %s: heartbeat fallback insert skipped: %v", p.ID, p.SessionID, err) } } diff --git a/management/server/store/store.go b/management/server/store/store.go index 0d8b0678a..447c85547 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -284,7 +284,8 @@ type Store interface { DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error SaveProxy(ctx context.Context, proxy *proxy.Proxy) error - UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + DisconnectProxy(ctx context.Context, proxyID, sessionID string) error + UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index beee13d96..d8bd826a8 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -178,6 +178,7 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr int mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr) } + // Close mocks base method. func (m *MockStore) Close(ctx context.Context) error { m.ctrl.T.Helper() @@ -2799,6 +2800,20 @@ func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy) } +// DisconnectProxy mocks base method. +func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID, sessionID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DisconnectProxy indicates an expected call of DisconnectProxy. +func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID, sessionID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID, sessionID) +} + // SaveProxyAccessToken mocks base method. func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error { m.ctrl.T.Helper() @@ -2995,17 +3010,17 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{} } // UpdateProxyHeartbeat mocks base method. -func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { +func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID, clusterAddress, ipAddress) + ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, p) ret0, _ := ret[0].(error) return ret0 } // UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat. -func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, p interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID, clusterAddress, ipAddress) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, p) } // UpdateService mocks base method. diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 4b1ecf922..e9eae3210 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -201,15 +201,15 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, // testProxyManager is a mock implementation of proxy.Manager for testing. type testProxyManager struct{} -func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error { +func (m *testProxyManager) Connect(_ context.Context, proxyID, sessionID, _, _ string, _ *nbproxy.Capabilities) (*nbproxy.Proxy, error) { + return &nbproxy.Proxy{ID: proxyID, SessionID: sessionID, Status: "connected"}, nil +} + +func (m *testProxyManager) Disconnect(_ context.Context, _, _ string) error { return nil } -func (m *testProxyManager) Disconnect(_ context.Context, _ string) error { - return nil -} - -func (m *testProxyManager) Heartbeat(_ context.Context, _, _, _ string) error { +func (m *testProxyManager) Heartbeat(_ context.Context, _ *nbproxy.Proxy) error { return nil } @@ -656,3 +656,72 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T) assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID) } } + +// TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState verifies that +// when a proxy reconnects before the old stream's cleanup runs, the new +// connection is NOT removed by the stale defer. +func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T) { + setup := setupIntegrationTest(t) + defer setup.cleanup() + + clusterAddress := "test.proxy.io" + proxyID := "test-proxy-race" + + conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer conn.Close() + + client := proto.NewProxyServiceClient(conn) + + ctx1, cancel1 := context.WithCancel(context.Background()) + stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{ + ProxyId: proxyID, + Version: "test-v1", + Address: clusterAddress, + }) + require.NoError(t, err) + + for i := 0; i < 2; i++ { + _, err := stream1.Recv() + require.NoError(t, err) + } + + require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID, + "proxy should be registered after first connection") + + ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel2() + + stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{ + ProxyId: proxyID, + Version: "test-v1", + Address: clusterAddress, + }) + require.NoError(t, err) + + for i := 0; i < 2; i++ { + _, err := stream2.Recv() + require.NoError(t, err) + } + + cancel1() + + time.Sleep(200 * time.Millisecond) + + assert.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID, + "proxy should still be registered after old connection cleanup — old defer must not remove new connection") + + setup.proxyService.SendServiceUpdate(&proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, + Id: "rp-1", + AccountId: "test-account-1", + Domain: "app1.test.proxy.io", + }}, + }) + + msg, err := stream2.Recv() + require.NoError(t, err, "new stream should still receive updates") + require.NotEmpty(t, msg.GetMapping(), "update should contain the mapping") + assert.Equal(t, "rp-1", msg.GetMapping()[0].GetId()) +}