diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index aae6ef2f9..cecde960d 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -3,6 +3,7 @@ package grpc import ( "context" "crypto/hmac" + "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/hex" @@ -747,10 +748,20 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU scopes = []string{oidc.ScopeOpenID, "profile", "email"} } + // Generate a random nonce to ensure each OIDC request gets a unique state. + // Without this, multiple requests to the same URL would generate the same state + // but different PKCE verifiers, causing the later verifier to overwrite the earlier one. + nonce := make([]byte, 16) + if _, err := rand.Read(nonce); err != nil { + return nil, status.Errorf(codes.Internal, "generate nonce: %v", err) + } + nonceB64 := base64.URLEncoding.EncodeToString(nonce) + // Using an HMAC here to avoid redirection state being modified. - // State format: base64(redirectURL)|hmac - hmacSum := s.generateHMAC(redirectURL.String()) - state := fmt.Sprintf("%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), hmacSum) + // State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce) + payload := redirectURL.String() + "|" + nonceB64 + hmacSum := s.generateHMAC(payload) + state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum) codeVerifier := oauth2.GenerateVerifier() s.pkceVerifiers.Store(state, pkceEntry{verifier: codeVerifier, createdAt: time.Now()}) @@ -803,13 +814,15 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL } verifier = entry.verifier + // State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce) parts := strings.Split(state, "|") - if len(parts) != 2 { + if len(parts) != 3 { return "", "", errors.New("invalid state format") } encodedURL := parts[0] - providedHMAC := parts[1] + nonce := parts[1] + providedHMAC := parts[2] redirectURLBytes, err := base64.URLEncoding.DecodeString(encodedURL) if err != nil { @@ -817,10 +830,11 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL } redirectURL = string(redirectURLBytes) - expectedHMAC := s.generateHMAC(redirectURL) + payload := redirectURL + "|" + nonce + expectedHMAC := s.generateHMAC(payload) if !hmac.Equal([]byte(providedHMAC), []byte(expectedHMAC)) { - return "", "", fmt.Errorf("invalid state signature") + return "", "", errors.New("invalid state signature") } return verifier, redirectURL, nil diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index 589c57611..060af6554 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -1,6 +1,9 @@ package grpc import ( + "crypto/rand" + "encoding/base64" + "strings" "sync" "testing" "time" @@ -162,3 +165,68 @@ func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { assert.NoError(t, tokenStore.ValidateAndConsume(msg1.AuthToken, "account-1", "service-1")) assert.NoError(t, tokenStore.ValidateAndConsume(msg2.AuthToken, "account-1", "service-1")) } + +// generateState creates a state using the same format as GetOIDCURL. +func generateState(s *ProxyServiceServer, redirectURL string) string { + nonce := make([]byte, 16) + rand.Read(nonce) + nonceB64 := base64.URLEncoding.EncodeToString(nonce) + + payload := redirectURL + "|" + nonceB64 + hmacSum := s.generateHMAC(payload) + return base64.URLEncoding.EncodeToString([]byte(redirectURL)) + "|" + nonceB64 + "|" + hmacSum +} + +func TestOAuthState_NeverTheSame(t *testing.T) { + s := &ProxyServiceServer{ + oidcConfig: ProxyOIDCConfig{ + HMACKey: []byte("test-hmac-key"), + }, + } + + redirectURL := "https://app.example.com/callback" + + // Generate 100 states for the same redirect URL + states := make(map[string]bool) + for i := 0; i < 100; i++ { + state := generateState(s, redirectURL) + + // State must have 3 parts: base64(url)|nonce|hmac + parts := strings.Split(state, "|") + require.Equal(t, 3, len(parts), "state must have 3 parts") + + // State must be unique + require.False(t, states[state], "state %d is a duplicate", i) + states[state] = true + } +} + +func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { + s := &ProxyServiceServer{ + oidcConfig: ProxyOIDCConfig{ + HMACKey: []byte("test-hmac-key"), + }, + } + + // Old format had only 2 parts: base64(url)|hmac + s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()}) + + _, _, err := s.ValidateState("base64url|hmac") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid state format") +} + +func TestValidateState_RejectsInvalidHMAC(t *testing.T) { + s := &ProxyServiceServer{ + oidcConfig: ProxyOIDCConfig{ + HMACKey: []byte("test-hmac-key"), + }, + } + + // Store with tampered HMAC + s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()}) + + _, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid state signature") +}