[management] Legacy to embedded IdP migration tool (#5586)

This commit is contained in:
shuuri-labs
2026-04-01 12:53:19 +01:00
committed by GitHub
parent 4d3e2f8ad3
commit 940f530ac2
24 changed files with 4023 additions and 31 deletions

View File

@@ -25,7 +25,7 @@ import (
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
type Jwks struct {
Keys []JSONWebKey `json:"keys"`
expiresInTime time.Time
ExpiresInTime time.Time `json:"-"`
}
// The supported elliptic curves types
@@ -53,12 +53,17 @@ type JSONWebKey struct {
X5c []string `json:"x5c"`
}
// KeyFetcher is a function that retrieves JWKS keys directly (e.g., from Dex storage)
// bypassing HTTP. When set on a Validator, it is used instead of the HTTP-based getPemKeys.
type KeyFetcher func(ctx context.Context) (*Jwks, error)
type Validator struct {
lock sync.Mutex
issuer string
audienceList []string
keysLocation string
idpSignkeyRefreshEnabled bool
keyFetcher KeyFetcher
keys *Jwks
lastForcedRefresh time.Time
}
@@ -85,10 +90,39 @@ func NewValidator(issuer string, audienceList []string, keysLocation string, idp
}
}
// NewValidatorWithKeyFetcher creates a Validator that fetches keys directly using the
// provided KeyFetcher (e.g., from Dex storage) instead of via HTTP.
func NewValidatorWithKeyFetcher(issuer string, audienceList []string, keyFetcher KeyFetcher) *Validator {
ctx := context.Background()
keys, err := keyFetcher(ctx)
if err != nil {
log.Warnf("could not get keys from key fetcher: %s, it will try again on the next http request", err)
}
if keys == nil {
keys = &Jwks{}
}
return &Validator{
keys: keys,
issuer: issuer,
audienceList: audienceList,
idpSignkeyRefreshEnabled: true,
keyFetcher: keyFetcher,
}
}
// forcedRefreshCooldown is the minimum time between forced key refreshes
// to prevent abuse from invalid tokens with fake kid values
const forcedRefreshCooldown = 30 * time.Second
// fetchKeys retrieves keys using the keyFetcher if available, otherwise falls back to HTTP.
func (v *Validator) fetchKeys(ctx context.Context) (*Jwks, error) {
if v.keyFetcher != nil {
return v.keyFetcher(ctx)
}
return getPemKeys(v.keysLocation)
}
func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc {
return func(token *jwt.Token) (interface{}, error) {
// If keys are rotated, verify the keys prior to token validation
@@ -131,13 +165,13 @@ func (v *Validator) refreshKeys(ctx context.Context) {
v.lock.Lock()
defer v.lock.Unlock()
refreshedKeys, err := getPemKeys(v.keysLocation)
refreshedKeys, err := v.fetchKeys(ctx)
if err != nil {
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
return
}
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.ExpiresInTime.UTC())
v.keys = refreshedKeys
}
@@ -155,13 +189,13 @@ func (v *Validator) forceRefreshKeys(ctx context.Context) bool {
log.WithContext(ctx).Debugf("key not found in cache, forcing JWKS refresh")
refreshedKeys, err := getPemKeys(v.keysLocation)
refreshedKeys, err := v.fetchKeys(ctx)
if err != nil {
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
return false
}
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.ExpiresInTime.UTC())
v.keys = refreshedKeys
v.lastForcedRefresh = time.Now()
return true
@@ -203,7 +237,7 @@ func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.To
// stillValid returns true if the JSONWebKey still valid and have enough time to be used
func (jwks *Jwks) stillValid() bool {
return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
return !jwks.ExpiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.ExpiresInTime)
}
func getPemKeys(keysLocation string) (*Jwks, error) {
@@ -227,7 +261,7 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
cacheControlHeader := resp.Header.Get("Cache-Control")
expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
jwks.ExpiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
return jwks, nil
}