diff --git a/backend/internal/controller/oidc_controller.go b/backend/internal/controller/oidc_controller.go index 669fb0a7..8dd81a5a 100644 --- a/backend/internal/controller/oidc_controller.go +++ b/backend/internal/controller/oidc_controller.go @@ -99,6 +99,7 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) { c.Request.Context(), input, c.GetString("userID"), + c.GetString("authenticationMethod"), c.ClientIP(), c.Request.UserAgent(), ) @@ -808,7 +809,14 @@ func (oc *OidcController) verifyDeviceCodeHandler(c *gin.Context) { ipAddress := c.ClientIP() userAgent := c.Request.UserAgent() - err := oc.oidcService.VerifyDeviceCode(c.Request.Context(), userCode, c.GetString("userID"), ipAddress, userAgent) + err := oc.oidcService.VerifyDeviceCode( + c.Request.Context(), + userCode, + c.GetString("userID"), + c.GetString("authenticationMethod"), + ipAddress, + userAgent) + if err != nil { _ = c.Error(err) return @@ -864,7 +872,13 @@ func (oc *OidcController) getClientPreviewHandler(c *gin.Context) { return } - preview, err := oc.oidcService.GetClientPreview(c.Request.Context(), clientID, userID, strings.Split(scopes, " ")) + preview, err := oc.oidcService.GetClientPreview( + c.Request.Context(), + clientID, + userID, + strings.Split(scopes, " "), + c.GetString("authenticationMethod")) + if err != nil { _ = c.Error(err) return diff --git a/backend/internal/middleware/auth_middleware.go b/backend/internal/middleware/auth_middleware.go index 3af9ce17..889ccdff 100644 --- a/backend/internal/middleware/auth_middleware.go +++ b/backend/internal/middleware/auth_middleware.go @@ -74,10 +74,11 @@ func (m *AuthMiddleware) WithApiKeyAuthDisabled() *AuthMiddleware { func (m *AuthMiddleware) Add() gin.HandlerFunc { return func(c *gin.Context) { - userID, isAdmin, err := m.jwtMiddleware.Verify(c, m.options.AdminRequired) + userID, isAdmin, authenticationMethod, err := m.jwtMiddleware.Verify(c, m.options.AdminRequired) if err == nil { c.Set("userID", userID) c.Set("userIsAdmin", isAdmin) + c.Set("authenticationMethod", authenticationMethod) if c.IsAborted() { return } diff --git a/backend/internal/middleware/auth_middleware_test.go b/backend/internal/middleware/auth_middleware_test.go index e3fb0ee6..cb39bcdf 100644 --- a/backend/internal/middleware/auth_middleware_test.go +++ b/backend/internal/middleware/auth_middleware_test.go @@ -44,7 +44,7 @@ func TestWithApiKeyAuthDisabled(t *testing.T) { authMiddleware := NewAuthMiddleware(apiKeyService, userService, jwtService) user := createUserForAuthMiddlewareTest(t, db) - jwtToken, err := jwtService.GenerateAccessToken(user) + jwtToken, err := jwtService.GenerateAccessToken(user, "") require.NoError(t, err) _, apiKeyToken, err := apiKeyService.CreateApiKey(t.Context(), user.ID, dto.ApiKeyCreateDto{ diff --git a/backend/internal/middleware/jwt_auth.go b/backend/internal/middleware/jwt_auth.go index 69d1728e..c311b537 100644 --- a/backend/internal/middleware/jwt_auth.go +++ b/backend/internal/middleware/jwt_auth.go @@ -20,7 +20,7 @@ func NewJwtAuthMiddleware(jwtService *service.JwtService, userService *service.U func (m *JwtAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc { return func(c *gin.Context) { - userID, isAdmin, err := m.Verify(c, adminRequired) + userID, isAdmin, authenticationMethod, err := m.Verify(c, adminRequired) if err != nil { c.Abort() _ = c.Error(err) @@ -29,11 +29,12 @@ func (m *JwtAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc { c.Set("userID", userID) c.Set("userIsAdmin", isAdmin) + c.Set("authenticationMethod", authenticationMethod) c.Next() } } -func (m *JwtAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (subject string, isAdmin bool, err error) { +func (m *JwtAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (subject string, isAdmin bool, authenticationMethod string, err error) { // Extract the token from the cookie accessToken, err := c.Cookie(cookie.AccessTokenCookieName) if err != nil { @@ -41,33 +42,37 @@ func (m *JwtAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (subject var ok bool _, accessToken, ok = strings.Cut(c.GetHeader("Authorization"), " ") if !ok || accessToken == "" { - return "", false, &common.NotSignedInError{} + return "", false, "", &common.NotSignedInError{} } } token, err := m.jwtService.VerifyAccessToken(accessToken) if err != nil { - return "", false, &common.NotSignedInError{} + return "", false, "", &common.NotSignedInError{} + } + authenticationMethod, err = service.GetAuthenticationMethod(token) + if err != nil { + return "", false, "", &common.NotSignedInError{} } subject, ok := token.Subject() if !ok { _ = c.Error(&common.TokenInvalidError{}) - return + return "", false, "", &common.TokenInvalidError{} } user, err := m.userService.GetUser(c, subject) if err != nil { - return "", false, &common.NotSignedInError{} + return "", false, "", &common.NotSignedInError{} } if user.Disabled { - return "", false, &common.UserDisabledError{} + return "", false, "", &common.UserDisabledError{} } if adminRequired && !user.IsAdmin { - return "", false, &common.MissingPermissionError{} + return "", false, "", &common.MissingPermissionError{} } - return subject, isAdmin, nil + return subject, user.IsAdmin, authenticationMethod, nil } diff --git a/backend/internal/model/oidc.go b/backend/internal/model/oidc.go index 7f48d2c6..1a3909b7 100644 --- a/backend/internal/model/oidc.go +++ b/backend/internal/model/oidc.go @@ -33,6 +33,7 @@ type OidcAuthorizationCode struct { Code string Scope string + AuthenticationMethod string Nonce string CodeChallenge *string CodeChallengeMethodSha256 *bool @@ -77,9 +78,10 @@ func (c OidcClient) HasDarkLogo() bool { type OidcRefreshToken struct { Base - Token string - ExpiresAt datatype.DateTime - Scope string + Token string + ExpiresAt datatype.DateTime + Scope string + AuthenticationMethod string UserID string User User @@ -141,12 +143,13 @@ func (cu UrlList) Value() (driver.Value, error) { type OidcDeviceCode struct { Base - DeviceCode string - UserCode string - Scope string - Nonce string - ExpiresAt datatype.DateTime - IsAuthorized bool + DeviceCode string + UserCode string + Scope string + AuthenticationMethod string + Nonce string + ExpiresAt datatype.DateTime + IsAuthorized bool UserID *string User User diff --git a/backend/internal/service/e2etest_service.go b/backend/internal/service/e2etest_service.go index 6cc76b8a..362c849e 100644 --- a/backend/internal/service/e2etest_service.go +++ b/backend/internal/service/e2etest_service.go @@ -245,20 +245,22 @@ func (s *TestService) SeedDatabase(baseURL string) error { authCodes := []model.OidcAuthorizationCode{ { - Code: "auth-code", - Scope: "openid profile", - Nonce: "nonce", - ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)), - UserID: users[0].ID, - ClientID: oidcClients[0].ID, + Code: "auth-code", + Scope: "openid profile", + Nonce: "nonce", + AuthenticationMethod: AuthenticationMethodPhishingResistant, + ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)), + UserID: users[0].ID, + ClientID: oidcClients[0].ID, }, { - Code: "federated", - Scope: "openid profile", - Nonce: "nonce", - ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)), - UserID: users[1].ID, - ClientID: oidcClients[3].ID, + Code: "federated", + Scope: "openid profile", + Nonce: "nonce", + AuthenticationMethod: AuthenticationMethodPhishingResistant, + ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)), + UserID: users[1].ID, + ClientID: oidcClients[3].ID, }, } for _, authCode := range authCodes { @@ -268,11 +270,12 @@ func (s *TestService) SeedDatabase(baseURL string) error { } refreshToken := model.OidcRefreshToken{ - Token: utils.CreateSha256Hash("ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo"), - ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)), - Scope: "openid profile email", - UserID: users[0].ID, - ClientID: oidcClients[0].ID, + Token: utils.CreateSha256Hash("ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo"), + AuthenticationMethod: AuthenticationMethodPhishingResistant, + ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)), + Scope: "openid profile email", + UserID: users[0].ID, + ClientID: oidcClients[0].ID, } if err := tx.Create(&refreshToken).Error; err != nil { return err diff --git a/backend/internal/service/jwt_service.go b/backend/internal/service/jwt_service.go index aa83254c..c3797a94 100644 --- a/backend/internal/service/jwt_service.go +++ b/backend/internal/service/jwt_service.go @@ -32,6 +32,15 @@ const ( // RefreshTokenClaim is the claim used for the refresh token's value RefreshTokenClaim = "rt" + // AuthenticationMethodsClaim is the claim used to identify how the user authenticated + AuthenticationMethodsClaim = "amr" + + // AuthenticationMethodPhishingResistant identifies phishing-resistant authentication, such as passkeys + AuthenticationMethodPhishingResistant = "phr" + + // AuthenticationMethodOneTimePassword identifies one-time password/code authentication + AuthenticationMethodOneTimePassword = "otp" + // OAuthAccessTokenJWTType identifies a JWT as an OAuth access token OAuthAccessTokenJWTType = "oauth-access-token" //nolint:gosec @@ -187,7 +196,8 @@ func (s *JwtService) SetKey(privateKey jwk.Key) error { return nil } -func (s *JwtService) GenerateAccessToken(user model.User) (string, error) { +func (s *JwtService) GenerateAccessToken(user model.User, authenticationMethod string) (string, error) { + now := time.Now() token, err := jwt.NewBuilder(). Subject(user.ID). @@ -215,6 +225,11 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) { return "", fmt.Errorf("failed to set 'isAdmin' claim in token: %w", err) } + err = SetAuthenticationMethods(token, authenticationMethod) + if err != nil { + return "", fmt.Errorf("failed to set '%s' claim in token: %w", AuthenticationMethodsClaim, err) + } + alg, _ := s.privateKey.Algorithm() signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey)) if err != nil { @@ -243,7 +258,7 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) { } // BuildIDToken creates an ID token with all claims -func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, nonce string) (jwt.Token, error) { +func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, nonce string, authenticationMethod string) (jwt.Token, error) { now := time.Now() token, err := jwt.NewBuilder(). Expiration(now.Add(1 * time.Hour)). @@ -265,6 +280,11 @@ func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, no return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err) } + err = SetAuthenticationMethods(token, authenticationMethod) + if err != nil { + return nil, fmt.Errorf("failed to set '%s' claim in token: %w", AuthenticationMethodsClaim, err) + } + for k, v := range userClaims { err = token.Set(k, v) if err != nil { @@ -283,8 +303,8 @@ func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, no } // GenerateIDToken creates and signs an ID token -func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string) (string, error) { - token, err := s.BuildIDToken(userClaims, clientID, nonce) +func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string, authenticationMethod string) (string, error) { + token, err := s.BuildIDToken(userClaims, clientID, nonce, authenticationMethod) if err != nil { return "", err } @@ -332,7 +352,7 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool) } // BuildOAuthAccessToken creates an OAuth access token with all claims -func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jwt.Token, error) { +func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string, authenticationMethod string) (jwt.Token, error) { now := time.Now() token, err := jwt.NewBuilder(). Subject(user.ID). @@ -355,12 +375,17 @@ func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jw return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err) } + err = SetAuthenticationMethods(token, authenticationMethod) + if err != nil { + return nil, fmt.Errorf("failed to set '%s' claim in token: %w", AuthenticationMethodsClaim, err) + } + return token, nil } // GenerateOAuthAccessToken creates and signs an OAuth access token -func (s *JwtService) GenerateOAuthAccessToken(user model.User, clientID string) (string, error) { - token, err := s.BuildOAuthAccessToken(user, clientID) +func (s *JwtService) GenerateOAuthAccessToken(user model.User, clientID string, authenticationMethod string) (string, error) { + token, err := s.BuildOAuthAccessToken(user, clientID, authenticationMethod) if err != nil { return "", err } @@ -534,6 +559,27 @@ func GetIsAdmin(token jwt.Token) (bool, error) { return isAdmin, nil } +// GetAuthenticationMethod returns the first authentication method in the "amr" claim in the token +func GetAuthenticationMethod(token jwt.Token) (string, error) { + if !token.Has(AuthenticationMethodsClaim) { + return "", nil + } + var rawAuthenticationMethods []any + err := token.Get(AuthenticationMethodsClaim, &rawAuthenticationMethods) + if err != nil { + return "", fmt.Errorf("failed to get '%s' claim from token: %w", AuthenticationMethodsClaim, err) + } + + if len(rawAuthenticationMethods) == 0 { + return "", nil + } + authenticationMethod, ok := rawAuthenticationMethods[0].(string) + if !ok { + return "", fmt.Errorf("invalid '%s' claim in token: expected array of strings", AuthenticationMethodsClaim) + } + return authenticationMethod, nil +} + // SetTokenType sets the "type" claim in the token func SetTokenType(token jwt.Token, tokenType string) error { if tokenType == "" { @@ -551,6 +597,14 @@ func SetIsAdmin(token jwt.Token, isAdmin bool) error { return token.Set(IsAdminClaim, isAdmin) } +// SetAuthenticationMethods sets the authentication method references claim in the token +func SetAuthenticationMethods(token jwt.Token, authenticationMethod string) error { + if authenticationMethod == "" { + return nil + } + return token.Set(AuthenticationMethodsClaim, []string{authenticationMethod}) +} + // SetAudienceString sets the "aud" claim with a value that is a string, and not an array // This is permitted by RFC 7519, and it's done here for backwards-compatibility func SetAudienceString(token jwt.Token, audience string) error { diff --git a/backend/internal/service/jwt_service_test.go b/backend/internal/service/jwt_service_test.go index 3adba7e0..e7150085 100644 --- a/backend/internal/service/jwt_service_test.go +++ b/backend/internal/service/jwt_service_test.go @@ -174,6 +174,7 @@ func TestJwtService_Init(t *testing.T) { _ = assert.True(t, ok) && assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original") }) + } func TestJwtService_GetPublicJWK(t *testing.T) { @@ -308,7 +309,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) { IsAdmin: false, } - tokenString, err := service.GenerateAccessToken(user) + tokenString, err := service.GenerateAccessToken(user, "") require.NoError(t, err, "Failed to generate access token") assert.NotEmpty(t, tokenString, "Token should not be empty") @@ -321,6 +322,9 @@ func TestGenerateVerifyAccessToken(t *testing.T) { isAdmin, err := GetIsAdmin(claims) _ = assert.NoError(t, err, "Failed to get isAdmin claim") && assert.False(t, isAdmin, "isAdmin should be false") + authenticationMethod, err := GetAuthenticationMethod(claims) + _ = assert.NoError(t, err, "Failed to get amr claim") && + assert.Empty(t, authenticationMethod, "amr should be empty when not specified") audience, ok := claims.Audience() _ = assert.True(t, ok, "Audience not found in token") && assert.Equal(t, []string{service.envConfig.AppURL}, audience, "Audience should contain the app URL") @@ -344,7 +348,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) { IsAdmin: true, } - tokenString, err := service.GenerateAccessToken(adminUser) + tokenString, err := service.GenerateAccessToken(adminUser, "") require.NoError(t, err, "Failed to generate access token") claims, err := service.VerifyAccessToken(tokenString) @@ -358,6 +362,24 @@ func TestGenerateVerifyAccessToken(t *testing.T) { assert.Equal(t, adminUser.ID, subject, "Token subject should match user ID") }) + t.Run("sets authentication method references claim when provided", func(t *testing.T) { + service, _, _ := setupJwtService(t, mockConfig) + + user := model.User{ + Base: model.Base{ID: "user-with-auth-method"}, + } + + tokenString, err := service.GenerateAccessToken(user, AuthenticationMethodPhishingResistant) + require.NoError(t, err, "Failed to generate access token") + + claims, err := service.VerifyAccessToken(tokenString) + require.NoError(t, err, "Failed to verify generated token") + + authenticationMethod, err := GetAuthenticationMethod(claims) + _ = assert.NoError(t, err, "Failed to get amr claim") && + assert.Equal(t, AuthenticationMethodPhishingResistant, authenticationMethod, "amr should match") + }) + t.Run("uses session duration from config", func(t *testing.T) { customMockConfig := NewTestAppConfigService(&model.AppConfig{ SessionDuration: model.AppConfigVariable{Value: "30"}, // 30 minutes @@ -368,7 +390,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) { Base: model.Base{ID: "user456"}, } - tokenString, err := service.GenerateAccessToken(user) + tokenString, err := service.GenerateAccessToken(user, "") require.NoError(t, err, "Failed to generate access token") claims, err := service.VerifyAccessToken(tokenString) @@ -396,7 +418,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) { IsAdmin: true, } - tokenString, err := service.GenerateAccessToken(user) + tokenString, err := service.GenerateAccessToken(user, "") require.NoError(t, err, "Failed to generate access token with Ed25519 key") assert.NotEmpty(t, tokenString, "Token should not be empty") @@ -433,7 +455,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) { IsAdmin: true, } - tokenString, err := service.GenerateAccessToken(user) + tokenString, err := service.GenerateAccessToken(user, "") require.NoError(t, err, "Failed to generate access token with ECDSA key") assert.NotEmpty(t, tokenString, "Token should not be empty") @@ -470,7 +492,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) { IsAdmin: true, } - tokenString, err := service.GenerateAccessToken(user) + tokenString, err := service.GenerateAccessToken(user, "") require.NoError(t, err, "Failed to generate access token with RSA key") assert.NotEmpty(t, tokenString, "Token should not be empty") @@ -508,7 +530,7 @@ func TestGenerateVerifyIdToken(t *testing.T) { } const clientID = "test-client-123" - tokenString, err := service.GenerateIDToken(userClaims, clientID, "") + tokenString, err := service.GenerateIDToken(userClaims, clientID, "", "") require.NoError(t, err, "Failed to generate ID token") assert.NotEmpty(t, tokenString, "Token should not be empty") @@ -593,7 +615,7 @@ func TestGenerateVerifyIdToken(t *testing.T) { const clientID = "test-client-456" nonce := "random-nonce-value" - tokenString, err := service.GenerateIDToken(userClaims, clientID, nonce) + tokenString, err := service.GenerateIDToken(userClaims, clientID, nonce, "") require.NoError(t, err, "Failed to generate ID token with nonce") publicKey, err := service.GetPublicJWK() @@ -614,7 +636,7 @@ func TestGenerateVerifyIdToken(t *testing.T) { userClaims := map[string]any{ "sub": "user789", } - tokenString, err := service.GenerateIDToken(userClaims, "client-789", "") + tokenString, err := service.GenerateIDToken(userClaims, "client-789", "", "") require.NoError(t, err, "Failed to generate ID token") service.envConfig.AppURL = "https://wrong-issuer.com" @@ -640,7 +662,7 @@ func TestGenerateVerifyIdToken(t *testing.T) { } const clientID = "eddsa-client-123" - tokenString, err := service.GenerateIDToken(userClaims, clientID, "") + tokenString, err := service.GenerateIDToken(userClaims, clientID, "", "") require.NoError(t, err, "Failed to generate ID token with key") assert.NotEmpty(t, tokenString, "Token should not be empty") @@ -677,7 +699,7 @@ func TestGenerateVerifyIdToken(t *testing.T) { } const clientID = "ecdsa-client-123" - tokenString, err := service.GenerateIDToken(userClaims, clientID, "") + tokenString, err := service.GenerateIDToken(userClaims, clientID, "", "") require.NoError(t, err, "Failed to generate ID token with key") assert.NotEmpty(t, tokenString, "Token should not be empty") @@ -715,7 +737,7 @@ func TestGenerateVerifyIdToken(t *testing.T) { } const clientID = "rsa-client-123" - tokenString, err := service.GenerateIDToken(userClaims, clientID, "") + tokenString, err := service.GenerateIDToken(userClaims, clientID, "", "") require.NoError(t, err, "Failed to generate ID token with key") assert.NotEmpty(t, tokenString, "Token should not be empty") @@ -745,7 +767,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { } const clientID = "test-client-123" - tokenString, err := service.GenerateOAuthAccessToken(user, clientID) + tokenString, err := service.GenerateOAuthAccessToken(user, clientID, "") require.NoError(t, err, "Failed to generate OAuth access token") assert.NotEmpty(t, tokenString, "Token should not be empty") @@ -772,6 +794,25 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { assert.InDelta(t, 0, timeDiff, 1.0, "Token should expire in approximately 1 hour") }) + t.Run("sets authentication method references claim when provided", func(t *testing.T) { + service, _, _ := setupJwtService(t, mockConfig) + + user := model.User{ + Base: model.Base{ID: "oauth-amr-user"}, + } + const clientID = "test-client-amr" + + tokenString, err := service.GenerateOAuthAccessToken(user, clientID, AuthenticationMethodPhishingResistant) + require.NoError(t, err, "Failed to generate OAuth access token") + + claims, err := service.VerifyOAuthAccessToken(tokenString) + require.NoError(t, err, "Failed to verify generated OAuth access token") + + authenticationMethod, err := GetAuthenticationMethod(claims) + _ = assert.NoError(t, err, "Failed to get amr claim") && + assert.Equal(t, AuthenticationMethodPhishingResistant, authenticationMethod, "amr should match") + }) + t.Run("fails verification for expired token", func(t *testing.T) { service, _, _ := setupJwtService(t, mockConfig) @@ -805,7 +846,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { user := model.User{Base: model.Base{ID: "user789"}} const clientID = "test-client-789" - tokenString, err := service1.GenerateOAuthAccessToken(user, clientID) + tokenString, err := service1.GenerateOAuthAccessToken(user, clientID, "") require.NoError(t, err, "Failed to generate OAuth access token") _, err = service2.VerifyOAuthAccessToken(tokenString) @@ -828,7 +869,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { } const clientID = "eddsa-oauth-client" - tokenString, err := service.GenerateOAuthAccessToken(user, clientID) + tokenString, err := service.GenerateOAuthAccessToken(user, clientID, "") require.NoError(t, err, "Failed to generate OAuth access token with key") assert.NotEmpty(t, tokenString, "Token should not be empty") @@ -865,7 +906,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { } const clientID = "ecdsa-oauth-client" - tokenString, err := service.GenerateOAuthAccessToken(user, clientID) + tokenString, err := service.GenerateOAuthAccessToken(user, clientID, "") require.NoError(t, err, "Failed to generate OAuth access token with key") assert.NotEmpty(t, tokenString, "Token should not be empty") @@ -902,7 +943,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) { } const clientID = "rsa-oauth-client" - tokenString, err := service.GenerateOAuthAccessToken(user, clientID) + tokenString, err := service.GenerateOAuthAccessToken(user, clientID, "") require.NoError(t, err, "Failed to generate OAuth access token with key") assert.NotEmpty(t, tokenString, "Token should not be empty") diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index 4f295a6e..e91a616a 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -123,7 +123,7 @@ func (s *OidcService) getJWKCache(ctx context.Context) (*jwk.Cache, error) { ) } -func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) { +func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID string, authenticationMethod string, ipAddress, userAgent string) (string, string, error) { tx := s.db.Begin() defer tx.Rollback() @@ -179,7 +179,7 @@ func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClie } // Create the authorization code - code, err := s.createAuthorizationCode(ctx, input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod, tx) + code, err := s.createAuthorizationCode(ctx, input.ClientID, userID, input.Scope, authenticationMethod, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod, tx) if err != nil { return "", "", err } @@ -312,17 +312,17 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O } // Explicitly use the input clientID for the audience claim to ensure consistency - idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, deviceAuth.Nonce) + idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, deviceAuth.Nonce, deviceAuth.AuthenticationMethod) if err != nil { return CreatedTokens{}, err } - refreshToken, err := s.createRefreshToken(ctx, input.ClientID, *deviceAuth.UserID, deviceAuth.Scope, tx) + refreshToken, err := s.createRefreshToken(ctx, input.ClientID, *deviceAuth.UserID, deviceAuth.Scope, deviceAuth.AuthenticationMethod, tx) if err != nil { return CreatedTokens{}, err } - accessToken, err := s.jwtService.GenerateOAuthAccessToken(deviceAuth.User, input.ClientID) + accessToken, err := s.jwtService.GenerateOAuthAccessToken(deviceAuth.User, input.ClientID, deviceAuth.AuthenticationMethod) if err != nil { return CreatedTokens{}, err } @@ -363,7 +363,7 @@ func (s *OidcService) createTokenFromClientCredentials(ctx context.Context, inpu audClaim = input.Resource } - accessToken, err := s.jwtService.GenerateOAuthAccessToken(dummyUser, audClaim) + accessToken, err := s.jwtService.GenerateOAuthAccessToken(dummyUser, audClaim, "") if err != nil { return CreatedTokens{}, err } @@ -411,18 +411,20 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu return CreatedTokens{}, err } - idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, authorizationCodeMetaData.Nonce) + authenticationMethod := authorizationCodeMetaData.AuthenticationMethod + + idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, authorizationCodeMetaData.Nonce, authenticationMethod) if err != nil { return CreatedTokens{}, err } // Generate a refresh token - refreshToken, err := s.createRefreshToken(ctx, input.ClientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx) + refreshToken, err := s.createRefreshToken(ctx, input.ClientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, authenticationMethod, tx) if err != nil { return CreatedTokens{}, err } - accessToken, err := s.jwtService.GenerateOAuthAccessToken(authorizationCodeMetaData.User, input.ClientID) + accessToken, err := s.jwtService.GenerateOAuthAccessToken(authorizationCodeMetaData.User, input.ClientID, authenticationMethod) if err != nil { return CreatedTokens{}, err } @@ -500,7 +502,8 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto } // Generate a new access token - accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID) + authenticationMethods := storedRefreshToken.AuthenticationMethod + accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID, authenticationMethods) if err != nil { return CreatedTokens{}, err } @@ -513,13 +516,13 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto // Generate a new ID token // There's no nonce here because we don't have one with the refresh token, but that's not required - idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, "") + idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, "", authenticationMethods) if err != nil { return CreatedTokens{}, err } // Generate a new refresh token and invalidate the old one - newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx) + newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, authenticationMethods, tx) if err != nil { return CreatedTokens{}, err } @@ -1140,7 +1143,7 @@ func (s *OidcService) ValidateEndSession(ctx context.Context, input dto.OidcLogo return callbackURL, nil } -func (s *OidcService) createAuthorizationCode(ctx context.Context, clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string, tx *gorm.DB) (string, error) { +func (s *OidcService) createAuthorizationCode(ctx context.Context, clientID string, userID string, scope string, authenticationMethod string, nonce string, codeChallenge string, codeChallengeMethod string, tx *gorm.DB) (string, error) { randomString, err := utils.GenerateRandomAlphanumericString(32) if err != nil { return "", err @@ -1154,6 +1157,7 @@ func (s *OidcService) createAuthorizationCode(ctx context.Context, clientID stri ClientID: clientID, UserID: userID, Scope: scope, + AuthenticationMethod: authenticationMethod, Nonce: nonce, CodeChallenge: &codeChallenge, CodeChallengeMethodSha256: &codeChallengeMethodSha256, @@ -1297,7 +1301,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O }, nil } -func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, userID string, ipAddress string, userAgent string) error { +func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, userID string, authenticationMethod string, ipAddress string, userAgent string) error { tx := s.db.Begin() defer func() { tx.Rollback() @@ -1346,6 +1350,7 @@ func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, use } deviceAuth.UserID = &userID + deviceAuth.AuthenticationMethod = authenticationMethod deviceAuth.IsAuthorized = true err = tx. @@ -1549,7 +1554,7 @@ func (s *OidcService) ListAccessibleOidcClients(ctx context.Context, userID stri return dtos, response, err } -func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, userID string, scope string, tx *gorm.DB) (string, error) { +func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, userID string, scope string, authenticationMethod string, tx *gorm.DB) (string, error) { refreshToken, err := utils.GenerateRandomAlphanumericString(40) if err != nil { return "", err @@ -1560,11 +1565,12 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u refreshTokenHash := utils.CreateSha256Hash(refreshToken) m := model.OidcRefreshToken{ - ExpiresAt: datatype.DateTime(time.Now().Add(RefreshTokenDuration)), - Token: refreshTokenHash, - ClientID: clientID, - UserID: userID, - Scope: scope, + ExpiresAt: datatype.DateTime(time.Now().Add(RefreshTokenDuration)), + Token: refreshTokenHash, + ClientID: clientID, + UserID: userID, + Scope: scope, + AuthenticationMethod: authenticationMethod, } err = tx. @@ -1780,7 +1786,7 @@ func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.C return nil } -func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes []string) (*dto.OidcClientPreviewDto, error) { +func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes []string, authenticationMethod string) (*dto.OidcClientPreviewDto, error) { tx := s.db.Begin() defer func() { tx.Rollback() @@ -1816,12 +1822,12 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use return nil, err } - idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "") + idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "", authenticationMethod) if err != nil { return nil, err } - accessToken, err := s.jwtService.BuildOAuthAccessToken(user, clientID) + accessToken, err := s.jwtService.BuildOAuthAccessToken(user, clientID, authenticationMethod) if err != nil { return nil, err } diff --git a/backend/internal/service/oidc_service_test.go b/backend/internal/service/oidc_service_test.go index 4374a430..78e6fac0 100644 --- a/backend/internal/service/oidc_service_test.go +++ b/backend/internal/service/oidc_service_test.go @@ -562,6 +562,46 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) { }) } +func TestOidcServiceAuthenticationMethodsPersistence(t *testing.T) { + mockConfig := NewTestAppConfigService(&model.AppConfig{ + SessionDuration: model.AppConfigVariable{Value: "60"}, + }) + jwtService, db, _ := setupJwtService(t, mockConfig) + service := &OidcService{ + db: db, + jwtService: jwtService, + } + authenticationMethod := AuthenticationMethodPhishingResistant + + t.Run("stores authentication method on authorization codes", func(t *testing.T) { + code, err := service.createAuthorizationCode( + t.Context(), + "amr-client", + "amr-user", + "openid profile", + authenticationMethod, + "", + "", + "", + db, + ) + require.NoError(t, err) + + var authorizationCode model.OidcAuthorizationCode + require.NoError(t, db.First(&authorizationCode, "code = ?", code).Error) + assert.Equal(t, authenticationMethod, authorizationCode.AuthenticationMethod) + }) + + t.Run("stores authentication methods on refresh tokens", func(t *testing.T) { + _, err := service.createRefreshToken(t.Context(), "amr-client", "amr-user", "openid profile", authenticationMethod, db) + require.NoError(t, err) + + var refreshToken model.OidcRefreshToken + require.NoError(t, db.First(&refreshToken, "client_id = ? AND user_id = ?", "amr-client", "amr-user").Error) + assert.Equal(t, authenticationMethod, refreshToken.AuthenticationMethod) + }) +} + func TestValidateCodeVerifier_Plain(t *testing.T) { require.False(t, validateCodeVerifier("", "", false)) require.False(t, validateCodeVerifier("", "", true)) diff --git a/backend/internal/service/one_time_access_service.go b/backend/internal/service/one_time_access_service.go index 1b84f498..54a399ac 100644 --- a/backend/internal/service/one_time_access_service.go +++ b/backend/internal/service/one_time_access_service.go @@ -197,7 +197,7 @@ func (s *OneTimeAccessService) ExchangeOneTimeAccessToken(ctx context.Context, t return model.User{}, "", &common.DeviceCodeInvalid{} } - accessToken, err := s.jwtService.GenerateAccessToken(oneTimeAccessToken.User) + accessToken, err := s.jwtService.GenerateAccessToken(oneTimeAccessToken.User, AuthenticationMethodOneTimePassword) if err != nil { return model.User{}, "", err } diff --git a/backend/internal/service/user_signup_service.go b/backend/internal/service/user_signup_service.go index 1c5974d1..222e1a01 100644 --- a/backend/internal/service/user_signup_service.go +++ b/backend/internal/service/user_signup_service.go @@ -87,7 +87,7 @@ func (s *UserSignUpService) SignUp(ctx context.Context, signupData dto.SignUpDto return model.User{}, "", err } - accessToken, err := s.jwtService.GenerateAccessToken(user) + accessToken, err := s.jwtService.GenerateAccessToken(user, "") if err != nil { return model.User{}, "", err } @@ -148,7 +148,7 @@ func (s *UserSignUpService) SignUpInitialAdmin(ctx context.Context, signUpData d return model.User{}, "", err } - token, err := s.jwtService.GenerateAccessToken(user) + token, err := s.jwtService.GenerateAccessToken(user, AuthenticationMethodOneTimePassword) if err != nil { return model.User{}, "", err } diff --git a/backend/internal/service/webauthn_service.go b/backend/internal/service/webauthn_service.go index 1aaa65e3..18e8ac29 100644 --- a/backend/internal/service/webauthn_service.go +++ b/backend/internal/service/webauthn_service.go @@ -266,7 +266,7 @@ func (s *WebAuthnService) VerifyLogin(ctx context.Context, sessionID string, cre return model.User{}, "", &common.UserDisabledError{} } - token, err := s.jwtService.GenerateAccessToken(*user) + token, err := s.jwtService.GenerateAccessToken(*user, AuthenticationMethodPhishingResistant) if err != nil { return model.User{}, "", err } @@ -389,6 +389,14 @@ func (s *WebAuthnService) CreateReauthenticationTokenWithAccessToken(ctx context return "", errors.New("access token does not contain user ID") } + authenticationMethod, err := GetAuthenticationMethod(token) + if err != nil { + return "", err + } + if authenticationMethod != AuthenticationMethodPhishingResistant { + return "", &common.ReauthenticationRequiredError{} + } + // Check if token is issued less than a minute ago tokenExpiration, ok := token.IssuedAt() if !ok || time.Since(tokenExpiration) > time.Minute { diff --git a/backend/internal/service/webauthn_service_test.go b/backend/internal/service/webauthn_service_test.go new file mode 100644 index 00000000..79ab0572 --- /dev/null +++ b/backend/internal/service/webauthn_service_test.go @@ -0,0 +1,68 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pocket-id/pocket-id/backend/internal/common" + "github.com/pocket-id/pocket-id/backend/internal/model" +) + +func TestCreateReauthenticationTokenWithAccessToken(t *testing.T) { + mockConfig := NewTestAppConfigService(&model.AppConfig{ + SessionDuration: model.AppConfigVariable{Value: "60"}, + }) + + setupService := func(t *testing.T) (*WebAuthnService, model.User) { + t.Helper() + + jwtService, db, _ := setupJwtService(t, mockConfig) + user := model.User{ + Base: model.Base{ID: "reauth-user"}, + Username: "reauth-user", + } + require.NoError(t, db.Create(&user).Error) + + return &WebAuthnService{ + db: db, + jwtService: jwtService, + }, user + } + + t.Run("accepts a fresh access token from WebAuthn login", func(t *testing.T) { + service, user := setupService(t) + accessToken, err := service.jwtService.GenerateAccessToken(user, AuthenticationMethodPhishingResistant) + require.NoError(t, err) + + reauthenticationToken, err := service.CreateReauthenticationTokenWithAccessToken(t.Context(), accessToken) + + require.NoError(t, err) + assert.NotEmpty(t, reauthenticationToken) + }) + + t.Run("rejects a fresh access token from one-time access login", func(t *testing.T) { + service, user := setupService(t) + accessToken, err := service.jwtService.GenerateAccessToken(user, AuthenticationMethodOneTimePassword) + require.NoError(t, err) + + reauthenticationToken, err := service.CreateReauthenticationTokenWithAccessToken(t.Context(), accessToken) + + assert.Empty(t, reauthenticationToken) + require.Error(t, err) + assert.ErrorAs(t, err, new(*common.ReauthenticationRequiredError)) + }) + + t.Run("rejects a fresh access token without an authentication method", func(t *testing.T) { + service, user := setupService(t) + accessToken, err := service.jwtService.GenerateAccessToken(user, "") + require.NoError(t, err) + + reauthenticationToken, err := service.CreateReauthenticationTokenWithAccessToken(t.Context(), accessToken) + + assert.Empty(t, reauthenticationToken) + require.Error(t, err) + assert.ErrorAs(t, err, new(*common.ReauthenticationRequiredError)) + }) +} diff --git a/backend/resources/migrations/postgres/20260418120000_oidc_amr_claim.down.sql b/backend/resources/migrations/postgres/20260418120000_oidc_amr_claim.down.sql new file mode 100644 index 00000000..5c87aed2 --- /dev/null +++ b/backend/resources/migrations/postgres/20260418120000_oidc_amr_claim.down.sql @@ -0,0 +1,3 @@ +ALTER TABLE oidc_authorization_codes DROP COLUMN authentication_method; +ALTER TABLE oidc_refresh_tokens DROP COLUMN authentication_method; +ALTER TABLE oidc_device_codes DROP COLUMN authentication_method; diff --git a/backend/resources/migrations/postgres/20260418120000_oidc_amr_claim.up.sql b/backend/resources/migrations/postgres/20260418120000_oidc_amr_claim.up.sql new file mode 100644 index 00000000..e4dcd5fc --- /dev/null +++ b/backend/resources/migrations/postgres/20260418120000_oidc_amr_claim.up.sql @@ -0,0 +1,6 @@ +ALTER TABLE oidc_authorization_codes + ADD COLUMN authentication_method TEXT NOT NULL DEFAULT ''; +ALTER TABLE oidc_refresh_tokens + ADD COLUMN authentication_method TEXT NOT NULL DEFAULT ''; +ALTER TABLE oidc_device_codes + ADD COLUMN authentication_method TEXT NOT NULL DEFAULT ''; diff --git a/backend/resources/migrations/sqlite/20260418120000_oidc_amr_claim.down.sql b/backend/resources/migrations/sqlite/20260418120000_oidc_amr_claim.down.sql new file mode 100644 index 00000000..5c87aed2 --- /dev/null +++ b/backend/resources/migrations/sqlite/20260418120000_oidc_amr_claim.down.sql @@ -0,0 +1,3 @@ +ALTER TABLE oidc_authorization_codes DROP COLUMN authentication_method; +ALTER TABLE oidc_refresh_tokens DROP COLUMN authentication_method; +ALTER TABLE oidc_device_codes DROP COLUMN authentication_method; diff --git a/backend/resources/migrations/sqlite/20260418120000_oidc_amr_claim.up.sql b/backend/resources/migrations/sqlite/20260418120000_oidc_amr_claim.up.sql new file mode 100644 index 00000000..e4dcd5fc --- /dev/null +++ b/backend/resources/migrations/sqlite/20260418120000_oidc_amr_claim.up.sql @@ -0,0 +1,6 @@ +ALTER TABLE oidc_authorization_codes + ADD COLUMN authentication_method TEXT NOT NULL DEFAULT ''; +ALTER TABLE oidc_refresh_tokens + ADD COLUMN authentication_method TEXT NOT NULL DEFAULT ''; +ALTER TABLE oidc_device_codes + ADD COLUMN authentication_method TEXT NOT NULL DEFAULT ''; diff --git a/tests/resources/export/database.json b/tests/resources/export/database.json index 670d80de..8644dcba 100644 --- a/tests/resources/export/database.json +++ b/tests/resources/export/database.json @@ -1,6 +1,6 @@ { "provider": "sqlite", - "version": 20260109090200, + "version": 20260418120000, "tableOrder": ["users", "user_groups", "oidc_clients", "signup_tokens"], "tables": { "api_keys": [ @@ -42,6 +42,7 @@ "oidc_authorization_codes": [ { "client_id": "3654a746-35d4-4321-ac61-0bdcff2b4055", + "authentication_method": "phr", "code": "auth-code", "code_challenge": null, "code_challenge_method_sha256": null, @@ -54,6 +55,7 @@ }, { "client_id": "c48232ff-ff65-45ed-ae96-7afa8a9b443b", + "authentication_method": "phr", "code": "federated", "code_challenge": null, "code_challenge_method_sha256": null, @@ -168,6 +170,7 @@ "oidc_refresh_tokens": [ { "client_id": "3654a746-35d4-4321-ac61-0bdcff2b4055", + "authentication_method": "phr", "created_at": "2025-11-25T12:39:02Z", "expires_at": "2025-11-26T12:39:02Z", "id": "4928604e-e689-410c-9b25-5b9b6db9e46e",