diff --git a/idp/dex/provider.go b/idp/dex/provider.go index 68fe48486..24aed1b99 100644 --- a/idp/dex/provider.go +++ b/idp/dex/provider.go @@ -4,6 +4,7 @@ package dex import ( "context" "encoding/base64" + "encoding/json" "errors" "fmt" "log/slog" @@ -19,10 +20,13 @@ import ( "github.com/dexidp/dex/server" "github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage/sql" + jose "github.com/go-jose/go-jose/v4" "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "golang.org/x/crypto/bcrypt" "google.golang.org/grpc" + + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) // Config matches what management/internals/server/server.go expects @@ -666,3 +670,46 @@ func (p *Provider) GetAuthorizationEndpoint() string { } return issuer + "/auth" } + +// GetJWKS reads signing keys directly from Dex storage and returns them as Jwks. +// This avoids HTTP round-trips when the embedded IDP is co-located with the management server. +// The key retrieval mirrors Dex's own handlePublicKeys/ValidationKeys logic: +// SigningKeyPub first, then all VerificationKeys, serialized via go-jose. +func (p *Provider) GetJWKS(ctx context.Context) (*nbjwt.Jwks, error) { + keys, err := p.storage.GetKeys(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get keys from storage: %w", err) + } + + if keys.SigningKeyPub == nil { + return nil, fmt.Errorf("no public keys found in storage") + } + + // Build the key set exactly as Dex's localSigner.ValidationKeys does: + // signing key first, then all verification (rotated) keys. + joseKeys := make([]jose.JSONWebKey, 0, len(keys.VerificationKeys)+1) + joseKeys = append(joseKeys, *keys.SigningKeyPub) + for _, vk := range keys.VerificationKeys { + if vk.PublicKey != nil { + joseKeys = append(joseKeys, *vk.PublicKey) + } + } + + // Serialize through go-jose (same as Dex's handlePublicKeys handler) + // then deserialize into our Jwks type, so the JSON field mapping is identical + // to what the /keys HTTP endpoint would return. + joseSet := jose.JSONWebKeySet{Keys: joseKeys} + data, err := json.Marshal(joseSet) + if err != nil { + return nil, fmt.Errorf("failed to marshal JWKS: %w", err) + } + + jwks := &nbjwt.Jwks{} + if err := json.Unmarshal(data, jwks); err != nil { + return nil, fmt.Errorf("failed to unmarshal JWKS: %w", err) + } + + jwks.ExpiresInTime = keys.NextRotation + + return jwks, nil +} diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 62ed659c0..c7eab3d19 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager { @@ -71,6 +72,7 @@ func (s *BaseServer) AuthManager() auth.Manager { signingKeyRefreshEnabled := s.Config.HttpConfig.IdpSignKeyRefreshEnabled issuer := s.Config.HttpConfig.AuthIssuer userIDClaim := s.Config.HttpConfig.AuthUserIDClaim + var keyFetcher nbjwt.KeyFetcher // Use embedded IdP configuration if available if oauthProvider := s.OAuthConfigProvider(); oauthProvider != nil { @@ -78,8 +80,11 @@ func (s *BaseServer) AuthManager() auth.Manager { if len(audiences) > 0 { audience = audiences[0] // Use the first client ID as the primary audience } - // Use localhost keys location for internal validation (management has embedded Dex) - keysLocation = oauthProvider.GetLocalKeysLocation() + keyFetcher = oauthProvider.GetKeyFetcher() + // Fall back to default keys location if direct key fetching is not available + if keyFetcher == nil { + keysLocation = oauthProvider.GetLocalKeysLocation() + } signingKeyRefreshEnabled = true issuer = oauthProvider.GetIssuer() userIDClaim = oauthProvider.GetUserIDClaim() @@ -92,7 +97,8 @@ func (s *BaseServer) AuthManager() auth.Manager { keysLocation, userIDClaim, audiences, - signingKeyRefreshEnabled) + signingKeyRefreshEnabled, + keyFetcher) }) } diff --git a/management/server/auth/manager.go b/management/server/auth/manager.go index 76cc750b6..27346a604 100644 --- a/management/server/auth/manager.go +++ b/management/server/auth/manager.go @@ -33,15 +33,20 @@ type manager struct { extractor *nbjwt.ClaimsExtractor } -func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool) Manager { - // @note if invalid/missing parameters are sent the validator will instantiate - // but it will fail when validating and parsing the token - jwtValidator := nbjwt.NewValidator( - issuer, - allAudiences, - keysLocation, - idpRefreshKeys, - ) +func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool, keyFetcher nbjwt.KeyFetcher) Manager { + var jwtValidator *nbjwt.Validator + if keyFetcher != nil { + jwtValidator = nbjwt.NewValidatorWithKeyFetcher(issuer, allAudiences, keyFetcher) + } else { + // @note if invalid/missing parameters are sent the validator will instantiate + // but it will fail when validating and parsing the token + jwtValidator = nbjwt.NewValidator( + issuer, + allAudiences, + keysLocation, + idpRefreshKeys, + ) + } claimsExtractor := nbjwt.NewClaimsExtractor( nbjwt.WithAudience(audience), diff --git a/management/server/auth/manager_test.go b/management/server/auth/manager_test.go index b9f091b1e..469737f47 100644 --- a/management/server/auth/manager_test.go +++ b/management/server/auth/manager_test.go @@ -52,7 +52,7 @@ func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } - manager := auth.NewManager(store, "", "", "", "", []string{}, false) + manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil) user, pat, _, _, err := manager.GetPATInfo(context.Background(), token) if err != nil { @@ -92,7 +92,7 @@ func TestAuthManager_MarkPATUsed(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } - manager := auth.NewManager(store, "", "", "", "", []string{}, false) + manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil) err = manager.MarkPATUsed(context.Background(), "tokenId") if err != nil { @@ -142,7 +142,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) { // these tests only assert groups are parsed from token as per account settings token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}}) - manager := auth.NewManager(store, "", "", "", "", []string{}, false) + manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil) t.Run("JWT groups disabled", func(t *testing.T) { userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token) @@ -225,7 +225,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) { keyId := "test-key" // note, we can use a nil store because ValidateAndParseToken does not use it in it's flow - manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false) + manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false, nil) customClaim := func(name string) string { return fmt.Sprintf("%s/%s", audience, name) diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index c6e57b1be..d9d85a0a2 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -119,7 +119,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee am.SetServiceManager(serviceManager) // @note this is required so that PAT's validate from store, but JWT's are mocked - authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false) + authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil) authManagerMock := &serverauth.MockManager{ ValidateAndParseTokenFunc: mockValidateAndParseToken, EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, @@ -248,7 +248,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin am.SetServiceManager(serviceManager) // @note this is required so that PAT's validate from store, but JWT's are mocked - authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false) + authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil) authManagerMock := &serverauth.MockManager{ ValidateAndParseTokenFunc: mockValidateAndParseToken, EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, diff --git a/management/server/idp/embedded.go b/management/server/idp/embedded.go index 2cc7b9743..179b89298 100644 --- a/management/server/idp/embedded.go +++ b/management/server/idp/embedded.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/idp/dex" "github.com/netbirdio/netbird/management/server/telemetry" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) const ( @@ -193,6 +194,9 @@ type OAuthConfigProvider interface { // Management server has embedded Dex and can validate tokens via localhost, // avoiding external network calls and DNS resolution issues during startup. GetLocalKeysLocation() string + // GetKeyFetcher returns a KeyFetcher that reads keys directly from the IDP storage, + // or nil if direct key fetching is not supported (falls back to HTTP). + GetKeyFetcher() nbjwt.KeyFetcher GetClientIDs() []string GetUserIDClaim() string GetTokenEndpoint() string @@ -593,6 +597,11 @@ func (m *EmbeddedIdPManager) GetCLIRedirectURLs() []string { return m.config.CLIRedirectURIs } +// GetKeyFetcher returns a KeyFetcher that reads keys directly from Dex storage. +func (m *EmbeddedIdPManager) GetKeyFetcher() nbjwt.KeyFetcher { + return m.provider.GetJWKS +} + // GetKeysLocation returns the JWKS endpoint URL for token validation. func (m *EmbeddedIdPManager) GetKeysLocation() string { return m.provider.GetKeysLocation() diff --git a/shared/auth/jwt/validator.go b/shared/auth/jwt/validator.go index aeaa5842c..5a014de2d 100644 --- a/shared/auth/jwt/validator.go +++ b/shared/auth/jwt/validator.go @@ -25,7 +25,7 @@ import ( // Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation type Jwks struct { Keys []JSONWebKey `json:"keys"` - expiresInTime time.Time + ExpiresInTime time.Time `json:"-"` } // The supported elliptic curves types @@ -53,12 +53,17 @@ type JSONWebKey struct { X5c []string `json:"x5c"` } +// KeyFetcher is a function that retrieves JWKS keys directly (e.g., from Dex storage) +// bypassing HTTP. When set on a Validator, it is used instead of the HTTP-based getPemKeys. +type KeyFetcher func(ctx context.Context) (*Jwks, error) + type Validator struct { lock sync.Mutex issuer string audienceList []string keysLocation string idpSignkeyRefreshEnabled bool + keyFetcher KeyFetcher keys *Jwks lastForcedRefresh time.Time } @@ -85,10 +90,36 @@ func NewValidator(issuer string, audienceList []string, keysLocation string, idp } } +// NewValidatorWithKeyFetcher creates a Validator that fetches keys directly using the +// provided KeyFetcher (e.g., from Dex storage) instead of via HTTP. +func NewValidatorWithKeyFetcher(issuer string, audienceList []string, keyFetcher KeyFetcher) *Validator { + ctx := context.Background() + keys, err := keyFetcher(ctx) + if err != nil { + log.Warnf("could not get keys from key fetcher: %s, it will try again on the next http request", err) + } + + return &Validator{ + keys: keys, + issuer: issuer, + audienceList: audienceList, + idpSignkeyRefreshEnabled: true, + keyFetcher: keyFetcher, + } +} + // forcedRefreshCooldown is the minimum time between forced key refreshes // to prevent abuse from invalid tokens with fake kid values const forcedRefreshCooldown = 30 * time.Second +// fetchKeys retrieves keys using the keyFetcher if available, otherwise falls back to HTTP. +func (v *Validator) fetchKeys(ctx context.Context) (*Jwks, error) { + if v.keyFetcher != nil { + return v.keyFetcher(ctx) + } + return getPemKeys(v.keysLocation) +} + func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc { return func(token *jwt.Token) (interface{}, error) { // If keys are rotated, verify the keys prior to token validation @@ -131,13 +162,13 @@ func (v *Validator) refreshKeys(ctx context.Context) { v.lock.Lock() defer v.lock.Unlock() - refreshedKeys, err := getPemKeys(v.keysLocation) + refreshedKeys, err := v.fetchKeys(ctx) if err != nil { log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) return } - log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) + log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.ExpiresInTime.UTC()) v.keys = refreshedKeys } @@ -155,13 +186,13 @@ func (v *Validator) forceRefreshKeys(ctx context.Context) bool { log.WithContext(ctx).Debugf("key not found in cache, forcing JWKS refresh") - refreshedKeys, err := getPemKeys(v.keysLocation) + refreshedKeys, err := v.fetchKeys(ctx) if err != nil { log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) return false } - log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) + log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.ExpiresInTime.UTC()) v.keys = refreshedKeys v.lastForcedRefresh = time.Now() return true @@ -203,7 +234,7 @@ func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.To // stillValid returns true if the JSONWebKey still valid and have enough time to be used func (jwks *Jwks) stillValid() bool { - return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime) + return !jwks.ExpiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.ExpiresInTime) } func getPemKeys(keysLocation string) (*Jwks, error) { @@ -227,7 +258,7 @@ func getPemKeys(keysLocation string) (*Jwks, error) { cacheControlHeader := resp.Header.Get("Cache-Control") expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader) - jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second) + jwks.ExpiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second) return jwks, nil }