diff --git a/management/server/config.go b/management/server/config.go index 9ec16b3e8..5fd7fd8d4 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -80,6 +80,8 @@ type HttpServerConfig struct { AuthKeysLocation string // OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration OIDCConfigEndpoint string + // KeyRotationEnabled identifies the signing key is currently being rotated or not + KeyRotationEnabled bool } // Host represents a Wiretrustee host (e.g. STUN, TURN, Signal) diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go index 147f8f2eb..d0143c3e5 100644 --- a/management/server/jwtclaims/jwtValidator.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -12,6 +12,9 @@ import ( "fmt" "math/big" "net/http" + "strconv" + "strings" + "time" "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" @@ -45,7 +48,8 @@ type Options struct { // Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation type Jwks struct { - Keys []JSONWebKey `json:"keys"` + Keys []JSONWebKey `json:"keys"` + expiresInTime time.Time } // JSONWebKey is a representation of a Jason Web Key @@ -64,7 +68,7 @@ type JWTValidator struct { } // NewJWTValidator constructor -func NewJWTValidator(issuer string, audienceList []string, keysLocation string) (*JWTValidator, error) { +func NewJWTValidator(issuer string, audienceList []string, keysLocation string, keyRotationEnabled bool) (*JWTValidator, error) { keys, err := getPemKeys(keysLocation) if err != nil { return nil, err @@ -89,6 +93,19 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string) return token, errors.New("invalid issuer") } + // If keys are rotated, verify the keys prior to token validation + if keyRotationEnabled { + // If the keys are invalid, retrieve new ones + if !keys.stillValid() { + + keys, err = getPemKeys(keysLocation) + if err != nil { + log.Errorf("cannot get JSONWebKey: %v", err) + return nil, err + } + } + } + cert, err := getPemCert(token, keys) if err != nil { return nil, err @@ -154,6 +171,11 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) { return parsedToken, nil } +// stillValid returns true if the JSONWebKey still valid and have enough time to be used +func (jwks *Jwks) stillValid() bool { + return jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime) +} + func getPemKeys(keysLocation string) (*Jwks, error) { resp, err := http.Get(keysLocation) if err != nil { @@ -167,6 +189,10 @@ func getPemKeys(keysLocation string) (*Jwks, error) { return jwks, err } + cacheControlHeader := resp.Header.Get("Cache-Control") + expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader) + jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second) + return jwks, err } @@ -248,3 +274,26 @@ func convertExponentStringToInt(stringExponent string) (int, error) { return int(exponent), nil } + +// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header +func getMaxAgeFromCacheHeader(cacheControl string) int { + // Split into individual directives + directives := strings.Split(cacheControl, ",") + + for _, directive := range directives { + directive = strings.TrimSpace(directive) + if strings.HasPrefix(directive, "max-age=") { + // Extract the max-age value + maxAgeStr := strings.TrimPrefix(directive, "max-age=") + maxAge, err := strconv.Atoi(maxAgeStr) + if err != nil { + log.Debugf("error parsing max-age: %v", err) + return 0 + } + + return maxAge + } + } + + return 0 +}