mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Generate a random nonce to ensure each OIDC request gets a unique state
This commit is contained in:
@@ -3,6 +3,7 @@ package grpc
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
@@ -747,10 +748,20 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
|
||||
scopes = []string{oidc.ScopeOpenID, "profile", "email"}
|
||||
}
|
||||
|
||||
// Generate a random nonce to ensure each OIDC request gets a unique state.
|
||||
// Without this, multiple requests to the same URL would generate the same state
|
||||
// but different PKCE verifiers, causing the later verifier to overwrite the earlier one.
|
||||
nonce := make([]byte, 16)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "generate nonce: %v", err)
|
||||
}
|
||||
nonceB64 := base64.URLEncoding.EncodeToString(nonce)
|
||||
|
||||
// Using an HMAC here to avoid redirection state being modified.
|
||||
// State format: base64(redirectURL)|hmac
|
||||
hmacSum := s.generateHMAC(redirectURL.String())
|
||||
state := fmt.Sprintf("%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), hmacSum)
|
||||
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
|
||||
payload := redirectURL.String() + "|" + nonceB64
|
||||
hmacSum := s.generateHMAC(payload)
|
||||
state := fmt.Sprintf("%s|%s|%s", base64.URLEncoding.EncodeToString([]byte(redirectURL.String())), nonceB64, hmacSum)
|
||||
|
||||
codeVerifier := oauth2.GenerateVerifier()
|
||||
s.pkceVerifiers.Store(state, pkceEntry{verifier: codeVerifier, createdAt: time.Now()})
|
||||
@@ -803,13 +814,15 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
||||
}
|
||||
verifier = entry.verifier
|
||||
|
||||
// State format: base64(redirectURL)|nonce|hmac(redirectURL|nonce)
|
||||
parts := strings.Split(state, "|")
|
||||
if len(parts) != 2 {
|
||||
if len(parts) != 3 {
|
||||
return "", "", errors.New("invalid state format")
|
||||
}
|
||||
|
||||
encodedURL := parts[0]
|
||||
providedHMAC := parts[1]
|
||||
nonce := parts[1]
|
||||
providedHMAC := parts[2]
|
||||
|
||||
redirectURLBytes, err := base64.URLEncoding.DecodeString(encodedURL)
|
||||
if err != nil {
|
||||
@@ -817,10 +830,11 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
||||
}
|
||||
redirectURL = string(redirectURLBytes)
|
||||
|
||||
expectedHMAC := s.generateHMAC(redirectURL)
|
||||
payload := redirectURL + "|" + nonce
|
||||
expectedHMAC := s.generateHMAC(payload)
|
||||
|
||||
if !hmac.Equal([]byte(providedHMAC), []byte(expectedHMAC)) {
|
||||
return "", "", fmt.Errorf("invalid state signature")
|
||||
return "", "", errors.New("invalid state signature")
|
||||
}
|
||||
|
||||
return verifier, redirectURL, nil
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -162,3 +165,68 @@ func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
||||
assert.NoError(t, tokenStore.ValidateAndConsume(msg1.AuthToken, "account-1", "service-1"))
|
||||
assert.NoError(t, tokenStore.ValidateAndConsume(msg2.AuthToken, "account-1", "service-1"))
|
||||
}
|
||||
|
||||
// generateState creates a state using the same format as GetOIDCURL.
|
||||
func generateState(s *ProxyServiceServer, redirectURL string) string {
|
||||
nonce := make([]byte, 16)
|
||||
rand.Read(nonce)
|
||||
nonceB64 := base64.URLEncoding.EncodeToString(nonce)
|
||||
|
||||
payload := redirectURL + "|" + nonceB64
|
||||
hmacSum := s.generateHMAC(payload)
|
||||
return base64.URLEncoding.EncodeToString([]byte(redirectURL)) + "|" + nonceB64 + "|" + hmacSum
|
||||
}
|
||||
|
||||
func TestOAuthState_NeverTheSame(t *testing.T) {
|
||||
s := &ProxyServiceServer{
|
||||
oidcConfig: ProxyOIDCConfig{
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
},
|
||||
}
|
||||
|
||||
redirectURL := "https://app.example.com/callback"
|
||||
|
||||
// Generate 100 states for the same redirect URL
|
||||
states := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
state := generateState(s, redirectURL)
|
||||
|
||||
// State must have 3 parts: base64(url)|nonce|hmac
|
||||
parts := strings.Split(state, "|")
|
||||
require.Equal(t, 3, len(parts), "state must have 3 parts")
|
||||
|
||||
// State must be unique
|
||||
require.False(t, states[state], "state %d is a duplicate", i)
|
||||
states[state] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
||||
s := &ProxyServiceServer{
|
||||
oidcConfig: ProxyOIDCConfig{
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
},
|
||||
}
|
||||
|
||||
// Old format had only 2 parts: base64(url)|hmac
|
||||
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
||||
|
||||
_, _, err := s.ValidateState("base64url|hmac")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid state format")
|
||||
}
|
||||
|
||||
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
||||
s := &ProxyServiceServer{
|
||||
oidcConfig: ProxyOIDCConfig{
|
||||
HMACKey: []byte("test-hmac-key"),
|
||||
},
|
||||
}
|
||||
|
||||
// Store with tampered HMAC
|
||||
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
|
||||
|
||||
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid state signature")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user