From 30c02ab78c7bf05d610a636f7a2446144dd0ba01 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 9 Mar 2026 12:23:06 +0100 Subject: [PATCH] [management] use the cache for the pkce state (#5516) --- .../service/manager/manager_test.go | 15 +++-- management/internals/server/boot.go | 12 +++- management/internals/server/server.go | 1 - .../internals/shared/grpc/pkce_verifier.go | 61 +++++++++++++++++++ management/internals/shared/grpc/proxy.go | 59 +++--------------- .../internals/shared/grpc/proxy_test.go | 55 +++++++++++++---- .../shared/grpc/validate_session_test.go | 5 +- management/server/account_test.go | 2 +- .../proxy/auth_callback_integration_test.go | 4 ++ .../testing/testing_tools/channel/channel.go | 6 +- proxy/management_integration_test.go | 4 ++ 11 files changed, 152 insertions(+), 72 deletions(-) create mode 100644 management/internals/shared/grpc/pkce_verifier.go diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index 0cb8fa02a..ba4e1c805 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -423,8 +423,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { t.Helper() tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100) require.NoError(t, err) - srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) - t.Cleanup(srv.Close) + pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) return srv } @@ -703,8 +704,9 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100) require.NoError(t, err) - proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) - t.Cleanup(proxySrv.Close) + pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) require.NoError(t, err) @@ -1134,8 +1136,9 @@ func TestDeleteService_DeletesTargets(t *testing.T) { tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100) require.NoError(t, err) - proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) - t.Cleanup(proxySrv.Close) + pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) require.NoError(t, err) diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 2049f0051..eb13a15e3 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -168,7 +168,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer { - proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager()) + proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager()) s.AfterInit(func(s *BaseServer) { proxyService.SetServiceManager(s.ServiceManager()) proxyService.SetProxyController(s.ServiceProxyController()) @@ -203,6 +203,16 @@ func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore { }) } +func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore { + return Create(s, func() *nbgrpc.PKCEVerifierStore { + pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100) + if err != nil { + log.Fatalf("failed to create PKCE verifier store: %v", err) + } + return pkceStore + }) +} + func (s *BaseServer) AccessLogsManager() accesslogs.Manager { return Create(s, func() accesslogs.Manager { accessLogManager := accesslogsmanager.NewManager(s.Store(), s.PermissionsManager(), s.GeoLocationManager()) diff --git a/management/internals/server/server.go b/management/internals/server/server.go index 573983a79..9b8716da1 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -248,7 +248,6 @@ func (s *BaseServer) Stop() error { _ = s.certManager.Listener().Close() } s.GRPCServer().Stop() - s.ReverseProxyGRPCServer().Close() if s.proxyAuthClose != nil { s.proxyAuthClose() s.proxyAuthClose = nil diff --git a/management/internals/shared/grpc/pkce_verifier.go b/management/internals/shared/grpc/pkce_verifier.go new file mode 100644 index 000000000..441e8b051 --- /dev/null +++ b/management/internals/shared/grpc/pkce_verifier.go @@ -0,0 +1,61 @@ +package grpc + +import ( + "context" + "fmt" + "time" + + "github.com/eko/gocache/lib/v4/cache" + "github.com/eko/gocache/lib/v4/store" + log "github.com/sirupsen/logrus" + + nbcache "github.com/netbirdio/netbird/management/server/cache" +) + +// PKCEVerifierStore manages PKCE verifiers for OAuth flows. +// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var. +type PKCEVerifierStore struct { + cache *cache.Cache[string] + ctx context.Context +} + +// NewPKCEVerifierStore creates a PKCE verifier store with automatic backend selection +func NewPKCEVerifierStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*PKCEVerifierStore, error) { + cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn) + if err != nil { + return nil, fmt.Errorf("failed to create cache store: %w", err) + } + + return &PKCEVerifierStore{ + cache: cache.New[string](cacheStore), + ctx: ctx, + }, nil +} + +// Store saves a PKCE verifier associated with an OAuth state parameter. +// The verifier is stored with the specified TTL and will be automatically deleted after expiration. +func (s *PKCEVerifierStore) Store(state, verifier string, ttl time.Duration) error { + if err := s.cache.Set(s.ctx, state, verifier, store.WithExpiration(ttl)); err != nil { + return fmt.Errorf("failed to store PKCE verifier: %w", err) + } + + log.Debugf("Stored PKCE verifier for state (expires in %s)", ttl) + return nil +} + +// LoadAndDelete retrieves and removes a PKCE verifier for the given state. +// Returns the verifier and true if found, or empty string and false if not found. +// This enforces single-use semantics for PKCE verifiers. +func (s *PKCEVerifierStore) LoadAndDelete(state string) (string, bool) { + verifier, err := s.cache.Get(s.ctx, state) + if err != nil { + log.Debugf("PKCE verifier not found for state") + return "", false + } + + if err := s.cache.Delete(s.ctx, state); err != nil { + log.Warnf("Failed to delete PKCE verifier for state: %v", err) + } + + return verifier, true +} diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 308da5e2f..e2d0f1abe 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -82,20 +82,12 @@ type ProxyServiceServer struct { // OIDC configuration for proxy authentication oidcConfig ProxyOIDCConfig - // TODO: use database to store these instead? - // pkceVerifiers stores PKCE code verifiers keyed by OAuth state. - // Entries expire after pkceVerifierTTL to prevent unbounded growth. - pkceVerifiers sync.Map - pkceCleanupCancel context.CancelFunc + // Store for PKCE verifiers + pkceVerifierStore *PKCEVerifierStore } const pkceVerifierTTL = 10 * time.Minute -type pkceEntry struct { - verifier string - createdAt time.Time -} - // proxyConnection represents a connected proxy type proxyConnection struct { proxyID string @@ -107,42 +99,21 @@ type proxyConnection struct { } // NewProxyServiceServer creates a new proxy service server. -func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer { - ctx, cancel := context.WithCancel(context.Background()) +func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer { + ctx := context.Background() s := &ProxyServiceServer{ accessLogManager: accessLogMgr, oidcConfig: oidcConfig, tokenStore: tokenStore, + pkceVerifierStore: pkceStore, peersManager: peersManager, usersManager: usersManager, proxyManager: proxyMgr, - pkceCleanupCancel: cancel, } - go s.cleanupPKCEVerifiers(ctx) go s.cleanupStaleProxies(ctx) return s } -// cleanupPKCEVerifiers periodically removes expired PKCE verifiers. -func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) { - ticker := time.NewTicker(pkceVerifierTTL) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - now := time.Now() - s.pkceVerifiers.Range(func(key, value any) bool { - if entry, ok := value.(pkceEntry); ok && now.Sub(entry.createdAt) > pkceVerifierTTL { - s.pkceVerifiers.Delete(key) - } - return true - }) - } - } -} - // cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) { ticker := time.NewTicker(5 * time.Minute) @@ -159,11 +130,6 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) { } } -// Close stops background goroutines. -func (s *ProxyServiceServer) Close() { - s.pkceCleanupCancel() -} - func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) { s.serviceManager = manager } @@ -790,7 +756,10 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum) codeVerifier := oauth2.GenerateVerifier() - s.pkceVerifiers.Store(state, pkceEntry{verifier: codeVerifier, createdAt: time.Now()}) + if err := s.pkceVerifierStore.Store(state, codeVerifier, pkceVerifierTTL); err != nil { + log.WithContext(ctx).Errorf("failed to store PKCE verifier: %v", err) + return nil, status.Errorf(codes.Internal, "store PKCE verifier: %v", err) + } return &proto.GetOIDCURLResponse{ Url: (&oauth2.Config{ @@ -827,18 +796,10 @@ func (s *ProxyServiceServer) generateHMAC(input string) string { // ValidateState validates the state parameter from an OAuth callback. // Returns the original redirect URL if valid, or an error if invalid. func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) { - v, ok := s.pkceVerifiers.LoadAndDelete(state) + verifier, ok := s.pkceVerifierStore.LoadAndDelete(state) if !ok { return "", "", errors.New("no verifier for state") } - entry, ok := v.(pkceEntry) - if !ok { - return "", "", errors.New("invalid verifier for state") - } - if time.Since(entry.createdAt) > pkceVerifierTTL { - return "", "", errors.New("PKCE verifier expired") - } - verifier = entry.verifier // State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce) parts := strings.Split(state, "|") diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index ddeadac5a..b7abb28b6 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -5,11 +5,10 @@ import ( "crypto/rand" "encoding/base64" "strings" + "sync" "testing" "time" - "sync" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -94,11 +93,16 @@ func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpda } func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { - tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + ctx := context.Background() + tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + + pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) require.NoError(t, err) s := &ProxyServiceServer{ - tokenStore: tokenStore, + tokenStore: tokenStore, + pkceVerifierStore: pkceStore, } s.SetProxyController(newTestProxyController()) @@ -151,11 +155,16 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { } func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { - tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + ctx := context.Background() + tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + + pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) require.NoError(t, err) s := &ProxyServiceServer{ - tokenStore: tokenStore, + tokenStore: tokenStore, + pkceVerifierStore: pkceStore, } s.SetProxyController(newTestProxyController()) @@ -185,11 +194,16 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { } func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { - tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + ctx := context.Background() + tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + + pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) require.NoError(t, err) s := &ProxyServiceServer{ - tokenStore: tokenStore, + tokenStore: tokenStore, + pkceVerifierStore: pkceStore, } s.SetProxyController(newTestProxyController()) @@ -241,10 +255,15 @@ func generateState(s *ProxyServiceServer, redirectURL string) string { } func TestOAuthState_NeverTheSame(t *testing.T) { + ctx := context.Background() + pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ HMACKey: []byte("test-hmac-key"), }, + pkceVerifierStore: pkceStore, } redirectURL := "https://app.example.com/callback" @@ -265,31 +284,43 @@ func TestOAuthState_NeverTheSame(t *testing.T) { } func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { + ctx := context.Background() + pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ HMACKey: []byte("test-hmac-key"), }, + pkceVerifierStore: pkceStore, } // Old format had only 2 parts: base64(url)|hmac - s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()}) + err = s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute) + require.NoError(t, err) - _, _, err := s.ValidateState("base64url|hmac") + _, _, err = s.ValidateState("base64url|hmac") require.Error(t, err) assert.Contains(t, err.Error(), "invalid state format") } func TestValidateState_RejectsInvalidHMAC(t *testing.T) { + ctx := context.Background() + pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ HMACKey: []byte("test-hmac-key"), }, + pkceVerifierStore: pkceStore, } // Store with tampered HMAC - s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()}) + err = s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute) + require.NoError(t, err) - _, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac") + _, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac") require.Error(t, err) assert.Contains(t, err.Error(), "invalid state signature") } diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index 124ddf620..647e8443b 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -41,7 +41,10 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup { tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) require.NoError(t, err) - proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager) + pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + + proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager) proxyService.SetServiceManager(serviceManager) createTestProxies(t, ctx, testStore) diff --git a/management/server/account_test.go b/management/server/account_test.go index a073d4fca..fdec43617 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3133,7 +3133,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU return nil, nil, err } - proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager) + proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager) proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{}) if err != nil { return nil, nil, err diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index c7fd08da8..3bed54e80 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -193,6 +193,9 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) require.NoError(t, err) + pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + usersManager := users.NewManager(testStore) oidcConfig := nbgrpc.ProxyOIDCConfig{ @@ -206,6 +209,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { proxyService := nbgrpc.NewProxyServiceServer( &testAccessLogManager{}, tokenStore, + pkceStore, oidcConfig, nil, usersManager, diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 1d74f88d5..5e33ad652 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -98,12 +98,16 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee if err != nil { t.Fatalf("Failed to create proxy token store: %v", err) } + pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + if err != nil { + t.Fatalf("Failed to create PKCE verifier store: %v", err) + } noopMeter := noop.NewMeterProvider().Meter("") proxyMgr, err := proxymanager.NewManager(store, noopMeter) if err != nil { t.Fatalf("Failed to create proxy manager: %v", err) } - proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) + proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) domainManager := manager.NewManager(store, proxyMgr, permissionsManager) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) if err != nil { diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 3e5a21400..6a0ecce30 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -116,6 +116,9 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) require.NoError(t, err) + pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + // Create real users manager usersManager := users.NewManager(testStore) @@ -131,6 +134,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { proxyService := nbgrpc.NewProxyServiceServer( &testAccessLogManager{}, tokenStore, + pkceStore, oidcConfig, nil, usersManager,