diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 7da1e6898..bc14f1618 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -188,7 +188,10 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig { func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore { return Create(s, func() *nbgrpc.OneTimeTokenStore { - tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute) + tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100) + if err != nil { + log.Fatalf("failed to create proxy token store: %v", err) + } log.Info("One-time token store initialized for proxy authentication") return tokenStore }) diff --git a/management/internals/shared/grpc/onetime_token.go b/management/internals/shared/grpc/onetime_token.go index dcc37c639..462198124 100644 --- a/management/internals/shared/grpc/onetime_token.go +++ b/management/internals/shared/grpc/onetime_token.go @@ -1,28 +1,21 @@ package grpc import ( + "context" "crypto/rand" + "crypto/sha256" "crypto/subtle" "encoding/base64" + "encoding/hex" "fmt" - "sync" "time" + "github.com/eko/gocache/lib/v4/store" log "github.com/sirupsen/logrus" + + nbcache "github.com/netbirdio/netbird/management/server/cache" ) -// OneTimeTokenStore manages short-lived, single-use authentication tokens -// for proxy-to-management RPC authentication. Tokens are generated when -// a service is created and must be used exactly once by the proxy -// to authenticate a subsequent RPC call. -type OneTimeTokenStore struct { - tokens map[string]*tokenMetadata - mu sync.RWMutex - cleanup *time.Ticker - cleanupDone chan struct{} -} - -// tokenMetadata stores information about a one-time token type tokenMetadata struct { ServiceID string AccountID string @@ -30,20 +23,24 @@ type tokenMetadata struct { CreatedAt time.Time } -// NewOneTimeTokenStore creates a new token store with automatic cleanup -// of expired tokens. The cleanupInterval determines how often expired -// tokens are removed from memory. -func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore { - store := &OneTimeTokenStore{ - tokens: make(map[string]*tokenMetadata), - cleanup: time.NewTicker(cleanupInterval), - cleanupDone: make(chan struct{}), +// OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC. +// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var. +type OneTimeTokenStore struct { + store store.StoreInterface + ctx context.Context +} + +// NewOneTimeTokenStore creates a token store with automatic backend selection +func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) { + cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn) + if err != nil { + return nil, fmt.Errorf("failed to create cache store: %w", err) } - // Start background cleanup goroutine - go store.cleanupExpired() - - return store + return &OneTimeTokenStore{ + store: cacheStore, + ctx: ctx, + }, nil } // GenerateToken creates a new cryptographically secure one-time token @@ -52,25 +49,25 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore { // // Returns the generated token string or an error if random generation fails. func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) { - // Generate 32 bytes (256 bits) of cryptographically secure random data randomBytes := make([]byte, 32) if _, err := rand.Read(randomBytes); err != nil { return "", fmt.Errorf("failed to generate random token: %w", err) } - // Encode as URL-safe base64 for easy transmission in gRPC token := base64.URLEncoding.EncodeToString(randomBytes) + hashedToken := hashToken(token) - s.mu.Lock() - defer s.mu.Unlock() - - s.tokens[token] = &tokenMetadata{ + metadata := &tokenMetadata{ ServiceID: serviceID, AccountID: accountID, ExpiresAt: time.Now().Add(ttl), CreatedAt: time.Now(), } + if err := s.store.Set(s.ctx, hashedToken, metadata, store.WithExpiration(ttl)); err != nil { + return "", fmt.Errorf("failed to store token: %w", err) + } + log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)", serviceID, accountID, ttl) @@ -88,80 +85,46 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time. // - Account ID doesn't match // - Reverse proxy ID doesn't match func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error { - s.mu.Lock() - defer s.mu.Unlock() + hashedToken := hashToken(token) - metadata, exists := s.tokens[token] - if !exists { - log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", - serviceID, accountID) + value, err := s.store.Get(s.ctx, hashedToken) + if err != nil { + log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID) return fmt.Errorf("invalid token") } - // Check expiration + metadata, ok := value.(*tokenMetadata) + if !ok { + log.Warnf("Token validation failed: invalid metadata type (proxy: %s, account: %s)", serviceID, accountID) + return fmt.Errorf("invalid token metadata") + } + if time.Now().After(metadata.ExpiresAt) { - delete(s.tokens, token) - log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", - serviceID, accountID) + s.store.Delete(s.ctx, hashedToken) + log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID) return fmt.Errorf("token expired") } - // Validate account ID using constant-time comparison (prevents timing attacks) if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 { - log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", - metadata.AccountID, accountID) + log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", metadata.AccountID, accountID) return fmt.Errorf("account ID mismatch") } - // Validate service ID using constant-time comparison if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 { - log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", - metadata.ServiceID, serviceID) + log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", metadata.ServiceID, serviceID) return fmt.Errorf("service ID mismatch") } - // Delete token immediately to enforce single-use - delete(s.tokens, token) + if err := s.store.Delete(s.ctx, hashedToken); err != nil { + log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err) + } - log.Infof("Token validated and consumed for proxy %s in account %s", - serviceID, accountID) + log.Infof("Token validated and consumed for proxy %s in account %s", serviceID, accountID) return nil } -// cleanupExpired removes expired tokens in the background to prevent memory leaks -func (s *OneTimeTokenStore) cleanupExpired() { - for { - select { - case <-s.cleanup.C: - s.mu.Lock() - now := time.Now() - removed := 0 - for token, metadata := range s.tokens { - if now.After(metadata.ExpiresAt) { - delete(s.tokens, token) - removed++ - } - } - if removed > 0 { - log.Debugf("Cleaned up %d expired one-time tokens", removed) - } - s.mu.Unlock() - case <-s.cleanupDone: - return - } - } -} - -// Close stops the cleanup goroutine and releases resources -func (s *OneTimeTokenStore) Close() { - s.cleanup.Stop() - close(s.cleanupDone) -} - -// GetTokenCount returns the current number of tokens in the store (for debugging/metrics) -func (s *OneTimeTokenStore) GetTokenCount() int { - s.mu.RLock() - defer s.mu.RUnlock() - return len(s.tokens) +func hashToken(token string) string { + hash := sha256.Sum256([]byte(token)) + return hex.EncodeToString(hash[:]) } diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 4771d35af..76bd1a247 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -486,33 +486,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { } } -// GetAvailableClusters returns information about all connected proxy clusters. -func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo { - clusterCounts := make(map[string]int) - s.clusterProxies.Range(func(key, value interface{}) bool { - clusterAddr := key.(string) - proxySet := value.(*sync.Map) - count := 0 - proxySet.Range(func(_, _ interface{}) bool { - count++ - return true - }) - if count > 0 { - clusterCounts[clusterAddr] = count - } - return true - }) - - clusters := make([]ClusterInfo, 0, len(clusterCounts)) - for addr, count := range clusterCounts { - clusters = append(clusters, ClusterInfo{ - Address: addr, - ConnectedProxies: count, - }) - } - return clusters -} - func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) if err != nil { diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index 4c84e6010..714569934 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -1,6 +1,7 @@ package grpc import ( + "context" "crypto/rand" "encoding/base64" "strings" @@ -41,8 +42,8 @@ func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping { } func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { - tokenStore := NewOneTimeTokenStore(time.Hour) - defer tokenStore.Close() + tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + require.NoError(t, err) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -96,8 +97,8 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { } func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { - tokenStore := NewOneTimeTokenStore(time.Hour) - defer tokenStore.Close() + tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + require.NoError(t, err) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -131,8 +132,8 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { } func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { - tokenStore := NewOneTimeTokenStore(time.Hour) - defer tokenStore.Close() + tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + require.NoError(t, err) s := &ProxyServiceServer{ tokenStore: tokenStore, diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index f76d3ada0..d40b32b7f 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -37,7 +37,10 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup { proxyManager := &testValidateSessionProxyManager{store: testStore} usersManager := &testValidateSessionUsersManager{store: testStore} - proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager) + tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + + proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager) proxyService.SetProxyManager(proxyManager) createTestProxies(t, ctx, testStore) diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index 0a9a560cd..e1e6b5680 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -178,7 +178,8 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { oidcServer := newFakeOIDCServer() - tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute) + tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) + require.NoError(t, err) usersManager := users.NewManager(testStore) diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index f5c2aafa6..957d23b44 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -91,7 +91,10 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee } accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil) - proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute) + proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) + if err != nil { + t.Fatalf("Failed to create proxy token store: %v", err) + } proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager) domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager) reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager) diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 53d7019f7..47e735e22 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -112,7 +112,8 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { } // Create real token store - tokenStore := nbgrpc.NewOneTimeTokenStore(5 * time.Minute) + tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) // Create real users manager usersManager := users.NewManager(testStore)