diff --git a/management/internals/modules/reverseproxy/manager/manager.go b/management/internals/modules/reverseproxy/manager/manager.go index 24eb5cb92..f7c2268eb 100644 --- a/management/internals/modules/reverseproxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/manager/manager.go @@ -31,18 +31,16 @@ type managerImpl struct { accountManager account.Manager permissionsManager permissions.Manager proxyGRPCServer *nbgrpc.ProxyServiceServer - tokenStore *nbgrpc.OneTimeTokenStore clusterDeriver ClusterDeriver } // NewManager creates a new service manager. -func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, tokenStore *nbgrpc.OneTimeTokenStore, clusterDeriver ClusterDeriver) reverseproxy.Manager { +func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager { return &managerImpl{ store: store, accountManager: accountManager, permissionsManager: permissionsManager, proxyGRPCServer: proxyGRPCServer, - tokenStore: tokenStore, clusterDeriver: clusterDeriver, } } @@ -187,11 +185,6 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin return nil, err } - token, err := m.tokenStore.GenerateToken(accountID, service.ID, 5*time.Minute) - if err != nil { - return nil, fmt.Errorf("failed to generate authentication token: %w", err) - } - m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta()) err = m.replaceHostByLookup(ctx, accountID, service) @@ -199,7 +192,7 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) } - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, token, m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) + m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) m.accountManager.UpdateAccountPeers(ctx, accountID) @@ -293,22 +286,17 @@ func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID strin return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) } - token, err := m.tokenStore.GenerateToken(accountID, service.ID, 5*time.Minute) - if err != nil { - return nil, fmt.Errorf("failed to generate authentication token: %w", err) - } - + oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig() switch { case domainChanged && oldCluster != service.ProxyCluster: - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), oldCluster) - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, token, m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) + m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), oldCluster) + m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster) case !service.Enabled && serviceEnabledChanged: - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) + m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster) case service.Enabled && serviceEnabledChanged: - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, token, m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) + m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster) default: - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) - + m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster) } m.accountManager.UpdateAccountPeers(ctx, accountID) diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 0ddf02046..d83dde58d 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -192,7 +192,7 @@ func (s *BaseServer) RecordsManager() records.Manager { func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager { return Create(s, func() reverseproxy.Manager { - return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ProxyTokenStore(), s.ReverseProxyDomainManager()) + return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager()) }) } diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 9bd55151a..aae6ef2f9 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -345,14 +345,19 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA } // SendServiceUpdate broadcasts a service update to all connected proxy servers. -// Management should call this when services are created/updated/removed +// Management should call this when services are created/updated/removed. +// For create/update operations a unique one-time auth token is generated per +// proxy so that every replica can independently authenticate with management. func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) { - // Send it to all connected proxy servers log.Debugf("Broadcasting service update to all connected proxy servers") s.connectedProxies.Range(func(key, value interface{}) bool { conn := value.(*proxyConnection) + msg := s.perProxyMessage(update, conn.proxyID) + if msg == nil { + return true + } select { - case conn.sendChan <- update: + case conn.sendChan <- msg: log.Debugf("Sent service update with id %s to proxy server %s", update.Id, conn.proxyID) default: log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID) @@ -420,6 +425,8 @@ func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) { // SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster. // If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility). +// For create/update operations a unique one-time auth token is generated per +// proxy so that every replica can independently authenticate with management. func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) { if clusterAddr == "" { s.SendServiceUpdate(update) @@ -437,8 +444,12 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMappi proxyID := key.(string) if connVal, ok := s.connectedProxies.Load(proxyID); ok { conn := connVal.(*proxyConnection) + msg := s.perProxyMessage(update, proxyID) + if msg == nil { + return true + } select { - case conn.sendChan <- update: + case conn.sendChan <- msg: log.Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) default: log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) @@ -448,6 +459,42 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMappi }) } +// perProxyMessage returns a copy of update with a fresh one-time token for +// create/update operations. For delete operations the original message is +// returned unchanged because proxies do not need to authenticate for removal. +// Returns nil if token generation fails (the proxy should be skipped). +func (s *ProxyServiceServer) perProxyMessage(update *proto.ProxyMapping, proxyID string) *proto.ProxyMapping { + if update.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED || update.AccountId == "" { + return update + } + + token, err := s.tokenStore.GenerateToken(update.AccountId, update.Id, 5*time.Minute) + if err != nil { + log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err) + return nil + } + + msg := shallowCloneMapping(update) + msg.AuthToken = token + return msg +} + +// shallowCloneMapping creates a shallow copy of a ProxyMapping, reusing the +// same slice/pointer fields. Only scalar fields that differ per proxy (AuthToken) +// should be set on the copy. +func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { + return &proto.ProxyMapping{ + Type: m.Type, + Id: m.Id, + AccountId: m.AccountId, + Domain: m.Domain, + Path: m.Path, + Auth: m.Auth, + PassHostHeader: m.PassHostHeader, + RewriteRedirects: m.RewriteRedirects, + } +} + // GetAvailableClusters returns information about all connected proxy clusters. func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo { clusterCounts := make(map[string]int) diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go new file mode 100644 index 000000000..589c57611 --- /dev/null +++ b/management/internals/shared/grpc/proxy_test.go @@ -0,0 +1,164 @@ +package grpc + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/management/proto" +) + +// registerFakeProxy adds a fake proxy connection to the server's internal maps +// and returns the channel where messages will be received. +func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping { + ch := make(chan *proto.ProxyMapping, 10) + conn := &proxyConnection{ + proxyID: proxyID, + address: clusterAddr, + sendChan: ch, + } + s.connectedProxies.Store(proxyID, conn) + + proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{}) + proxySet.(*sync.Map).Store(proxyID, struct{}{}) + + return ch +} + +func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping { + select { + case msg := <-ch: + return msg + case <-time.After(time.Second): + return nil + } +} + +func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { + tokenStore := NewOneTimeTokenStore(time.Hour) + defer tokenStore.Close() + + s := &ProxyServiceServer{ + tokenStore: tokenStore, + updatesChan: make(chan *proto.ProxyMapping, 100), + } + + const cluster = "proxy.example.com" + const numProxies = 3 + + channels := make([]chan *proto.ProxyMapping, numProxies) + for i := range numProxies { + id := "proxy-" + string(rune('a'+i)) + channels[i] = registerFakeProxy(s, id, cluster) + } + + update := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-1", + AccountId: "account-1", + Domain: "test.example.com", + Path: []*proto.PathMapping{ + {Path: "/", Target: "http://10.0.0.1:8080/"}, + }, + } + + s.SendServiceUpdateToCluster(update, cluster) + + tokens := make([]string, numProxies) + for i, ch := range channels { + msg := drainChannel(ch) + require.NotNil(t, msg, "proxy %d should receive a message", i) + assert.Equal(t, update.Domain, msg.Domain) + assert.Equal(t, update.Id, msg.Id) + assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i) + tokens[i] = msg.AuthToken + } + + // All tokens must be unique + tokenSet := make(map[string]struct{}) + for i, tok := range tokens { + _, exists := tokenSet[tok] + assert.False(t, exists, "proxy %d got duplicate token", i) + tokenSet[tok] = struct{}{} + } + + // Each token must be independently consumable + for i, tok := range tokens { + err := tokenStore.ValidateAndConsume(tok, "account-1", "service-1") + assert.NoError(t, err, "proxy %d token should validate successfully", i) + } +} + +func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { + tokenStore := NewOneTimeTokenStore(time.Hour) + defer tokenStore.Close() + + s := &ProxyServiceServer{ + tokenStore: tokenStore, + updatesChan: make(chan *proto.ProxyMapping, 100), + } + + const cluster = "proxy.example.com" + ch1 := registerFakeProxy(s, "proxy-a", cluster) + ch2 := registerFakeProxy(s, "proxy-b", cluster) + + update := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, + Id: "service-1", + AccountId: "account-1", + Domain: "test.example.com", + } + + s.SendServiceUpdateToCluster(update, cluster) + + msg1 := drainChannel(ch1) + msg2 := drainChannel(ch2) + require.NotNil(t, msg1) + require.NotNil(t, msg2) + + // Delete operations should not generate tokens + assert.Empty(t, msg1.AuthToken) + assert.Empty(t, msg2.AuthToken) + + // No tokens should have been created + assert.Equal(t, 0, tokenStore.GetTokenCount()) +} + +func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { + tokenStore := NewOneTimeTokenStore(time.Hour) + defer tokenStore.Close() + + s := &ProxyServiceServer{ + tokenStore: tokenStore, + updatesChan: make(chan *proto.ProxyMapping, 100), + } + + // Register proxies in different clusters (SendServiceUpdate broadcasts to all) + ch1 := registerFakeProxy(s, "proxy-a", "cluster-a") + ch2 := registerFakeProxy(s, "proxy-b", "cluster-b") + + update := &proto.ProxyMapping{ + Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, + Id: "service-1", + AccountId: "account-1", + Domain: "test.example.com", + } + + s.SendServiceUpdate(update) + + msg1 := drainChannel(ch1) + msg2 := drainChannel(ch2) + require.NotNil(t, msg1) + require.NotNil(t, msg2) + + assert.NotEmpty(t, msg1.AuthToken) + assert.NotEmpty(t, msg2.AuthToken) + assert.NotEqual(t, msg1.AuthToken, msg2.AuthToken, "tokens must be unique per proxy") + + // Both tokens should validate + assert.NoError(t, tokenStore.ValidateAndConsume(msg1.AuthToken, "account-1", "service-1")) + assert.NoError(t, tokenStore.ValidateAndConsume(msg2.AuthToken, "account-1", "service-1")) +} diff --git a/management/server/account_test.go b/management/server/account_test.go index 7c71b3241..44bb0fb1c 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3122,7 +3122,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU return nil, nil, err } - manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil, nil)) + manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil)) return manager, updateManager, nil } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index eecbe18a0..78fcb39f2 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -94,7 +94,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute) proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager) domainManager := domain.NewManager(store, proxyServiceServer) - reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, proxyTokenStore, domainManager) + reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager) proxyServiceServer.SetProxyManager(reverseProxyManager) am.SetServiceManager(reverseProxyManager)