mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[management] Legacy to embedded IdP migration tool (#5586)
This commit is contained in:
@@ -25,7 +25,7 @@ import (
|
||||
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
|
||||
type Jwks struct {
|
||||
Keys []JSONWebKey `json:"keys"`
|
||||
expiresInTime time.Time
|
||||
ExpiresInTime time.Time `json:"-"`
|
||||
}
|
||||
|
||||
// The supported elliptic curves types
|
||||
@@ -53,12 +53,17 @@ type JSONWebKey struct {
|
||||
X5c []string `json:"x5c"`
|
||||
}
|
||||
|
||||
// KeyFetcher is a function that retrieves JWKS keys directly (e.g., from Dex storage)
|
||||
// bypassing HTTP. When set on a Validator, it is used instead of the HTTP-based getPemKeys.
|
||||
type KeyFetcher func(ctx context.Context) (*Jwks, error)
|
||||
|
||||
type Validator struct {
|
||||
lock sync.Mutex
|
||||
issuer string
|
||||
audienceList []string
|
||||
keysLocation string
|
||||
idpSignkeyRefreshEnabled bool
|
||||
keyFetcher KeyFetcher
|
||||
keys *Jwks
|
||||
lastForcedRefresh time.Time
|
||||
}
|
||||
@@ -85,10 +90,39 @@ func NewValidator(issuer string, audienceList []string, keysLocation string, idp
|
||||
}
|
||||
}
|
||||
|
||||
// NewValidatorWithKeyFetcher creates a Validator that fetches keys directly using the
|
||||
// provided KeyFetcher (e.g., from Dex storage) instead of via HTTP.
|
||||
func NewValidatorWithKeyFetcher(issuer string, audienceList []string, keyFetcher KeyFetcher) *Validator {
|
||||
ctx := context.Background()
|
||||
keys, err := keyFetcher(ctx)
|
||||
if err != nil {
|
||||
log.Warnf("could not get keys from key fetcher: %s, it will try again on the next http request", err)
|
||||
}
|
||||
if keys == nil {
|
||||
keys = &Jwks{}
|
||||
}
|
||||
|
||||
return &Validator{
|
||||
keys: keys,
|
||||
issuer: issuer,
|
||||
audienceList: audienceList,
|
||||
idpSignkeyRefreshEnabled: true,
|
||||
keyFetcher: keyFetcher,
|
||||
}
|
||||
}
|
||||
|
||||
// forcedRefreshCooldown is the minimum time between forced key refreshes
|
||||
// to prevent abuse from invalid tokens with fake kid values
|
||||
const forcedRefreshCooldown = 30 * time.Second
|
||||
|
||||
// fetchKeys retrieves keys using the keyFetcher if available, otherwise falls back to HTTP.
|
||||
func (v *Validator) fetchKeys(ctx context.Context) (*Jwks, error) {
|
||||
if v.keyFetcher != nil {
|
||||
return v.keyFetcher(ctx)
|
||||
}
|
||||
return getPemKeys(v.keysLocation)
|
||||
}
|
||||
|
||||
func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc {
|
||||
return func(token *jwt.Token) (interface{}, error) {
|
||||
// If keys are rotated, verify the keys prior to token validation
|
||||
@@ -131,13 +165,13 @@ func (v *Validator) refreshKeys(ctx context.Context) {
|
||||
v.lock.Lock()
|
||||
defer v.lock.Unlock()
|
||||
|
||||
refreshedKeys, err := getPemKeys(v.keysLocation)
|
||||
refreshedKeys, err := v.fetchKeys(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
|
||||
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.ExpiresInTime.UTC())
|
||||
v.keys = refreshedKeys
|
||||
}
|
||||
|
||||
@@ -155,13 +189,13 @@ func (v *Validator) forceRefreshKeys(ctx context.Context) bool {
|
||||
|
||||
log.WithContext(ctx).Debugf("key not found in cache, forcing JWKS refresh")
|
||||
|
||||
refreshedKeys, err := getPemKeys(v.keysLocation)
|
||||
refreshedKeys, err := v.fetchKeys(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
|
||||
return false
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
|
||||
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.ExpiresInTime.UTC())
|
||||
v.keys = refreshedKeys
|
||||
v.lastForcedRefresh = time.Now()
|
||||
return true
|
||||
@@ -203,7 +237,7 @@ func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.To
|
||||
|
||||
// stillValid returns true if the JSONWebKey still valid and have enough time to be used
|
||||
func (jwks *Jwks) stillValid() bool {
|
||||
return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
|
||||
return !jwks.ExpiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.ExpiresInTime)
|
||||
}
|
||||
|
||||
func getPemKeys(keysLocation string) (*Jwks, error) {
|
||||
@@ -227,7 +261,7 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
|
||||
|
||||
cacheControlHeader := resp.Header.Get("Cache-Control")
|
||||
expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
|
||||
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
jwks.ExpiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user