mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 01:36:46 +00:00
Compare commits
3 Commits
feature/up
...
feature/us
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d3e5f508c | ||
|
|
8d09ded1db | ||
|
|
a49a052f05 |
2
go.mod
2
go.mod
@@ -49,6 +49,7 @@ require (
|
|||||||
github.com/eko/gocache/store/redis/v4 v4.2.2
|
github.com/eko/gocache/store/redis/v4 v4.2.2
|
||||||
github.com/fsnotify/fsnotify v1.9.0
|
github.com/fsnotify/fsnotify v1.9.0
|
||||||
github.com/gliderlabs/ssh v0.3.8
|
github.com/gliderlabs/ssh v0.3.8
|
||||||
|
github.com/go-jose/go-jose/v4 v4.1.3
|
||||||
github.com/godbus/dbus/v5 v5.1.0
|
github.com/godbus/dbus/v5 v5.1.0
|
||||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||||
github.com/golang/mock v1.6.0
|
github.com/golang/mock v1.6.0
|
||||||
@@ -181,7 +182,6 @@ require (
|
|||||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
|
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
|
||||||
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
|
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
|
||||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
|
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
|
||||||
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
|
|
||||||
github.com/go-ldap/ldap/v3 v3.4.12 // indirect
|
github.com/go-ldap/ldap/v3 v3.4.12 // indirect
|
||||||
github.com/go-logr/logr v1.4.3 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package dex
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -19,10 +20,13 @@ import (
|
|||||||
"github.com/dexidp/dex/server"
|
"github.com/dexidp/dex/server"
|
||||||
"github.com/dexidp/dex/storage"
|
"github.com/dexidp/dex/storage"
|
||||||
"github.com/dexidp/dex/storage/sql"
|
"github.com/dexidp/dex/storage/sql"
|
||||||
|
jose "github.com/go-jose/go-jose/v4"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config matches what management/internals/server/server.go expects
|
// Config matches what management/internals/server/server.go expects
|
||||||
@@ -666,3 +670,46 @@ func (p *Provider) GetAuthorizationEndpoint() string {
|
|||||||
}
|
}
|
||||||
return issuer + "/auth"
|
return issuer + "/auth"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetJWKS reads signing keys directly from Dex storage and returns them as Jwks.
|
||||||
|
// This avoids HTTP round-trips when the embedded IDP is co-located with the management server.
|
||||||
|
// The key retrieval mirrors Dex's own handlePublicKeys/ValidationKeys logic:
|
||||||
|
// SigningKeyPub first, then all VerificationKeys, serialized via go-jose.
|
||||||
|
func (p *Provider) GetJWKS(ctx context.Context) (*nbjwt.Jwks, error) {
|
||||||
|
keys, err := p.storage.GetKeys(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get keys from storage: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if keys.SigningKeyPub == nil {
|
||||||
|
return nil, fmt.Errorf("no public keys found in storage")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the key set exactly as Dex's localSigner.ValidationKeys does:
|
||||||
|
// signing key first, then all verification (rotated) keys.
|
||||||
|
joseKeys := make([]jose.JSONWebKey, 0, len(keys.VerificationKeys)+1)
|
||||||
|
joseKeys = append(joseKeys, *keys.SigningKeyPub)
|
||||||
|
for _, vk := range keys.VerificationKeys {
|
||||||
|
if vk.PublicKey != nil {
|
||||||
|
joseKeys = append(joseKeys, *vk.PublicKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize through go-jose (same as Dex's handlePublicKeys handler)
|
||||||
|
// then deserialize into our Jwks type, so the JSON field mapping is identical
|
||||||
|
// to what the /keys HTTP endpoint would return.
|
||||||
|
joseSet := jose.JSONWebKeySet{Keys: joseKeys}
|
||||||
|
data, err := json.Marshal(joseSet)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal JWKS: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jwks := &nbjwt.Jwks{}
|
||||||
|
if err := json.Unmarshal(data, jwks); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal JWKS: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jwks.ExpiresInTime = keys.NextRotation
|
||||||
|
|
||||||
|
return jwks, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/job"
|
"github.com/netbirdio/netbird/management/server/job"
|
||||||
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
|
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
|
||||||
@@ -71,6 +72,7 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
|||||||
signingKeyRefreshEnabled := s.Config.HttpConfig.IdpSignKeyRefreshEnabled
|
signingKeyRefreshEnabled := s.Config.HttpConfig.IdpSignKeyRefreshEnabled
|
||||||
issuer := s.Config.HttpConfig.AuthIssuer
|
issuer := s.Config.HttpConfig.AuthIssuer
|
||||||
userIDClaim := s.Config.HttpConfig.AuthUserIDClaim
|
userIDClaim := s.Config.HttpConfig.AuthUserIDClaim
|
||||||
|
var keyFetcher nbjwt.KeyFetcher
|
||||||
|
|
||||||
// Use embedded IdP configuration if available
|
// Use embedded IdP configuration if available
|
||||||
if oauthProvider := s.OAuthConfigProvider(); oauthProvider != nil {
|
if oauthProvider := s.OAuthConfigProvider(); oauthProvider != nil {
|
||||||
@@ -78,8 +80,11 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
|||||||
if len(audiences) > 0 {
|
if len(audiences) > 0 {
|
||||||
audience = audiences[0] // Use the first client ID as the primary audience
|
audience = audiences[0] // Use the first client ID as the primary audience
|
||||||
}
|
}
|
||||||
// Use localhost keys location for internal validation (management has embedded Dex)
|
keyFetcher = oauthProvider.GetKeyFetcher()
|
||||||
keysLocation = oauthProvider.GetLocalKeysLocation()
|
// Fall back to default keys location if direct key fetching is not available
|
||||||
|
if keyFetcher == nil {
|
||||||
|
keysLocation = oauthProvider.GetLocalKeysLocation()
|
||||||
|
}
|
||||||
signingKeyRefreshEnabled = true
|
signingKeyRefreshEnabled = true
|
||||||
issuer = oauthProvider.GetIssuer()
|
issuer = oauthProvider.GetIssuer()
|
||||||
userIDClaim = oauthProvider.GetUserIDClaim()
|
userIDClaim = oauthProvider.GetUserIDClaim()
|
||||||
@@ -92,7 +97,8 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
|||||||
keysLocation,
|
keysLocation,
|
||||||
userIDClaim,
|
userIDClaim,
|
||||||
audiences,
|
audiences,
|
||||||
signingKeyRefreshEnabled)
|
signingKeyRefreshEnabled,
|
||||||
|
keyFetcher)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -33,15 +33,20 @@ type manager struct {
|
|||||||
extractor *nbjwt.ClaimsExtractor
|
extractor *nbjwt.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool) Manager {
|
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool, keyFetcher nbjwt.KeyFetcher) Manager {
|
||||||
// @note if invalid/missing parameters are sent the validator will instantiate
|
var jwtValidator *nbjwt.Validator
|
||||||
// but it will fail when validating and parsing the token
|
if keyFetcher != nil {
|
||||||
jwtValidator := nbjwt.NewValidator(
|
jwtValidator = nbjwt.NewValidatorWithKeyFetcher(issuer, allAudiences, keyFetcher)
|
||||||
issuer,
|
} else {
|
||||||
allAudiences,
|
// @note if invalid/missing parameters are sent the validator will instantiate
|
||||||
keysLocation,
|
// but it will fail when validating and parsing the token
|
||||||
idpRefreshKeys,
|
jwtValidator = nbjwt.NewValidator(
|
||||||
)
|
issuer,
|
||||||
|
allAudiences,
|
||||||
|
keysLocation,
|
||||||
|
idpRefreshKeys,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
claimsExtractor := nbjwt.NewClaimsExtractor(
|
claimsExtractor := nbjwt.NewClaimsExtractor(
|
||||||
nbjwt.WithAudience(audience),
|
nbjwt.WithAudience(audience),
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
|
|||||||
t.Fatalf("Error when saving account: %s", err)
|
t.Fatalf("Error when saving account: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
|
||||||
|
|
||||||
user, pat, _, _, err := manager.GetPATInfo(context.Background(), token)
|
user, pat, _, _, err := manager.GetPATInfo(context.Background(), token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -92,7 +92,7 @@ func TestAuthManager_MarkPATUsed(t *testing.T) {
|
|||||||
t.Fatalf("Error when saving account: %s", err)
|
t.Fatalf("Error when saving account: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
|
||||||
|
|
||||||
err = manager.MarkPATUsed(context.Background(), "tokenId")
|
err = manager.MarkPATUsed(context.Background(), "tokenId")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -142,7 +142,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
|
|||||||
// these tests only assert groups are parsed from token as per account settings
|
// these tests only assert groups are parsed from token as per account settings
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}})
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}})
|
||||||
|
|
||||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
|
||||||
|
|
||||||
t.Run("JWT groups disabled", func(t *testing.T) {
|
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||||
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
||||||
@@ -225,7 +225,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
|
|||||||
keyId := "test-key"
|
keyId := "test-key"
|
||||||
|
|
||||||
// note, we can use a nil store because ValidateAndParseToken does not use it in it's flow
|
// note, we can use a nil store because ValidateAndParseToken does not use it in it's flow
|
||||||
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false)
|
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false, nil)
|
||||||
|
|
||||||
customClaim := func(name string) string {
|
customClaim := func(name string) string {
|
||||||
return fmt.Sprintf("%s/%s", audience, name)
|
return fmt.Sprintf("%s/%s", audience, name)
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
am.SetServiceManager(serviceManager)
|
am.SetServiceManager(serviceManager)
|
||||||
|
|
||||||
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
||||||
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
|
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil)
|
||||||
authManagerMock := &serverauth.MockManager{
|
authManagerMock := &serverauth.MockManager{
|
||||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||||
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
|
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
|
||||||
@@ -248,7 +248,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
|||||||
am.SetServiceManager(serviceManager)
|
am.SetServiceManager(serviceManager)
|
||||||
|
|
||||||
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
||||||
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
|
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil)
|
||||||
authManagerMock := &serverauth.MockManager{
|
authManagerMock := &serverauth.MockManager{
|
||||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||||
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
|
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/idp/dex"
|
"github.com/netbirdio/netbird/idp/dex"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -193,6 +194,9 @@ type OAuthConfigProvider interface {
|
|||||||
// Management server has embedded Dex and can validate tokens via localhost,
|
// Management server has embedded Dex and can validate tokens via localhost,
|
||||||
// avoiding external network calls and DNS resolution issues during startup.
|
// avoiding external network calls and DNS resolution issues during startup.
|
||||||
GetLocalKeysLocation() string
|
GetLocalKeysLocation() string
|
||||||
|
// GetKeyFetcher returns a KeyFetcher that reads keys directly from the IDP storage,
|
||||||
|
// or nil if direct key fetching is not supported (falls back to HTTP).
|
||||||
|
GetKeyFetcher() nbjwt.KeyFetcher
|
||||||
GetClientIDs() []string
|
GetClientIDs() []string
|
||||||
GetUserIDClaim() string
|
GetUserIDClaim() string
|
||||||
GetTokenEndpoint() string
|
GetTokenEndpoint() string
|
||||||
@@ -593,6 +597,11 @@ func (m *EmbeddedIdPManager) GetCLIRedirectURLs() []string {
|
|||||||
return m.config.CLIRedirectURIs
|
return m.config.CLIRedirectURIs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetKeyFetcher returns a KeyFetcher that reads keys directly from Dex storage.
|
||||||
|
func (m *EmbeddedIdPManager) GetKeyFetcher() nbjwt.KeyFetcher {
|
||||||
|
return m.provider.GetJWKS
|
||||||
|
}
|
||||||
|
|
||||||
// GetKeysLocation returns the JWKS endpoint URL for token validation.
|
// GetKeysLocation returns the JWKS endpoint URL for token validation.
|
||||||
func (m *EmbeddedIdPManager) GetKeysLocation() string {
|
func (m *EmbeddedIdPManager) GetKeysLocation() string {
|
||||||
return m.provider.GetKeysLocation()
|
return m.provider.GetKeysLocation()
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import (
|
|||||||
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
|
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
|
||||||
type Jwks struct {
|
type Jwks struct {
|
||||||
Keys []JSONWebKey `json:"keys"`
|
Keys []JSONWebKey `json:"keys"`
|
||||||
expiresInTime time.Time
|
ExpiresInTime time.Time `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// The supported elliptic curves types
|
// The supported elliptic curves types
|
||||||
@@ -53,12 +53,17 @@ type JSONWebKey struct {
|
|||||||
X5c []string `json:"x5c"`
|
X5c []string `json:"x5c"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// KeyFetcher is a function that retrieves JWKS keys directly (e.g., from Dex storage)
|
||||||
|
// bypassing HTTP. When set on a Validator, it is used instead of the HTTP-based getPemKeys.
|
||||||
|
type KeyFetcher func(ctx context.Context) (*Jwks, error)
|
||||||
|
|
||||||
type Validator struct {
|
type Validator struct {
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
issuer string
|
issuer string
|
||||||
audienceList []string
|
audienceList []string
|
||||||
keysLocation string
|
keysLocation string
|
||||||
idpSignkeyRefreshEnabled bool
|
idpSignkeyRefreshEnabled bool
|
||||||
|
keyFetcher KeyFetcher
|
||||||
keys *Jwks
|
keys *Jwks
|
||||||
lastForcedRefresh time.Time
|
lastForcedRefresh time.Time
|
||||||
}
|
}
|
||||||
@@ -85,10 +90,39 @@ func NewValidator(issuer string, audienceList []string, keysLocation string, idp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewValidatorWithKeyFetcher creates a Validator that fetches keys directly using the
|
||||||
|
// provided KeyFetcher (e.g., from Dex storage) instead of via HTTP.
|
||||||
|
func NewValidatorWithKeyFetcher(issuer string, audienceList []string, keyFetcher KeyFetcher) *Validator {
|
||||||
|
ctx := context.Background()
|
||||||
|
keys, err := keyFetcher(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("could not get keys from key fetcher: %s, it will try again on the next http request", err)
|
||||||
|
}
|
||||||
|
if keys == nil {
|
||||||
|
keys = &Jwks{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Validator{
|
||||||
|
keys: keys,
|
||||||
|
issuer: issuer,
|
||||||
|
audienceList: audienceList,
|
||||||
|
idpSignkeyRefreshEnabled: true,
|
||||||
|
keyFetcher: keyFetcher,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// forcedRefreshCooldown is the minimum time between forced key refreshes
|
// forcedRefreshCooldown is the minimum time between forced key refreshes
|
||||||
// to prevent abuse from invalid tokens with fake kid values
|
// to prevent abuse from invalid tokens with fake kid values
|
||||||
const forcedRefreshCooldown = 30 * time.Second
|
const forcedRefreshCooldown = 30 * time.Second
|
||||||
|
|
||||||
|
// fetchKeys retrieves keys using the keyFetcher if available, otherwise falls back to HTTP.
|
||||||
|
func (v *Validator) fetchKeys(ctx context.Context) (*Jwks, error) {
|
||||||
|
if v.keyFetcher != nil {
|
||||||
|
return v.keyFetcher(ctx)
|
||||||
|
}
|
||||||
|
return getPemKeys(v.keysLocation)
|
||||||
|
}
|
||||||
|
|
||||||
func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc {
|
func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc {
|
||||||
return func(token *jwt.Token) (interface{}, error) {
|
return func(token *jwt.Token) (interface{}, error) {
|
||||||
// If keys are rotated, verify the keys prior to token validation
|
// If keys are rotated, verify the keys prior to token validation
|
||||||
@@ -131,13 +165,13 @@ func (v *Validator) refreshKeys(ctx context.Context) {
|
|||||||
v.lock.Lock()
|
v.lock.Lock()
|
||||||
defer v.lock.Unlock()
|
defer v.lock.Unlock()
|
||||||
|
|
||||||
refreshedKeys, err := getPemKeys(v.keysLocation)
|
refreshedKeys, err := v.fetchKeys(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
|
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
|
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.ExpiresInTime.UTC())
|
||||||
v.keys = refreshedKeys
|
v.keys = refreshedKeys
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,13 +189,13 @@ func (v *Validator) forceRefreshKeys(ctx context.Context) bool {
|
|||||||
|
|
||||||
log.WithContext(ctx).Debugf("key not found in cache, forcing JWKS refresh")
|
log.WithContext(ctx).Debugf("key not found in cache, forcing JWKS refresh")
|
||||||
|
|
||||||
refreshedKeys, err := getPemKeys(v.keysLocation)
|
refreshedKeys, err := v.fetchKeys(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
|
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
|
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.ExpiresInTime.UTC())
|
||||||
v.keys = refreshedKeys
|
v.keys = refreshedKeys
|
||||||
v.lastForcedRefresh = time.Now()
|
v.lastForcedRefresh = time.Now()
|
||||||
return true
|
return true
|
||||||
@@ -203,7 +237,7 @@ func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.To
|
|||||||
|
|
||||||
// stillValid returns true if the JSONWebKey still valid and have enough time to be used
|
// stillValid returns true if the JSONWebKey still valid and have enough time to be used
|
||||||
func (jwks *Jwks) stillValid() bool {
|
func (jwks *Jwks) stillValid() bool {
|
||||||
return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
|
return !jwks.ExpiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.ExpiresInTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPemKeys(keysLocation string) (*Jwks, error) {
|
func getPemKeys(keysLocation string) (*Jwks, error) {
|
||||||
@@ -227,7 +261,7 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
|
|||||||
|
|
||||||
cacheControlHeader := resp.Header.Get("Cache-Control")
|
cacheControlHeader := resp.Header.Get("Cache-Control")
|
||||||
expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
|
expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
|
||||||
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
|
jwks.ExpiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||||
|
|
||||||
return jwks, nil
|
return jwks, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user