[management] use the cache for the pkce state (#5516)

This commit is contained in:
Pascal Fischer
2026-03-09 12:23:06 +01:00
committed by GitHub
parent 3acd86e346
commit 30c02ab78c
11 changed files with 152 additions and 72 deletions

View File

@@ -82,20 +82,12 @@ type ProxyServiceServer struct {
// OIDC configuration for proxy authentication
oidcConfig ProxyOIDCConfig
// TODO: use database to store these instead?
// pkceVerifiers stores PKCE code verifiers keyed by OAuth state.
// Entries expire after pkceVerifierTTL to prevent unbounded growth.
pkceVerifiers sync.Map
pkceCleanupCancel context.CancelFunc
// Store for PKCE verifiers
pkceVerifierStore *PKCEVerifierStore
}
const pkceVerifierTTL = 10 * time.Minute
type pkceEntry struct {
verifier string
createdAt time.Time
}
// proxyConnection represents a connected proxy
type proxyConnection struct {
proxyID string
@@ -107,42 +99,21 @@ type proxyConnection struct {
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx := context.Background()
s := &ProxyServiceServer{
accessLogManager: accessLogMgr,
oidcConfig: oidcConfig,
tokenStore: tokenStore,
pkceVerifierStore: pkceStore,
peersManager: peersManager,
usersManager: usersManager,
proxyManager: proxyMgr,
pkceCleanupCancel: cancel,
}
go s.cleanupPKCEVerifiers(ctx)
go s.cleanupStaleProxies(ctx)
return s
}
// cleanupPKCEVerifiers periodically removes expired PKCE verifiers.
func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
ticker := time.NewTicker(pkceVerifierTTL)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
now := time.Now()
s.pkceVerifiers.Range(func(key, value any) bool {
if entry, ok := value.(pkceEntry); ok && now.Sub(entry.createdAt) > pkceVerifierTTL {
s.pkceVerifiers.Delete(key)
}
return true
})
}
}
}
// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
ticker := time.NewTicker(5 * time.Minute)
@@ -159,11 +130,6 @@ func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
}
}
// Close stops background goroutines.
func (s *ProxyServiceServer) Close() {
s.pkceCleanupCancel()
}
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
s.serviceManager = manager
}
@@ -790,7 +756,10 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum)
codeVerifier := oauth2.GenerateVerifier()
s.pkceVerifiers.Store(state, pkceEntry{verifier: codeVerifier, createdAt: time.Now()})
if err := s.pkceVerifierStore.Store(state, codeVerifier, pkceVerifierTTL); err != nil {
log.WithContext(ctx).Errorf("failed to store PKCE verifier: %v", err)
return nil, status.Errorf(codes.Internal, "store PKCE verifier: %v", err)
}
return &proto.GetOIDCURLResponse{
Url: (&oauth2.Config{
@@ -827,18 +796,10 @@ func (s *ProxyServiceServer) generateHMAC(input string) string {
// ValidateState validates the state parameter from an OAuth callback.
// Returns the original redirect URL if valid, or an error if invalid.
func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL string, err error) {
v, ok := s.pkceVerifiers.LoadAndDelete(state)
verifier, ok := s.pkceVerifierStore.LoadAndDelete(state)
if !ok {
return "", "", errors.New("no verifier for state")
}
entry, ok := v.(pkceEntry)
if !ok {
return "", "", errors.New("invalid verifier for state")
}
if time.Since(entry.createdAt) > pkceVerifierTTL {
return "", "", errors.New("PKCE verifier expired")
}
verifier = entry.verifier
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
parts := strings.Split(state, "|")