feat: add auth method claim (amr) to tokens (#1433)

This commit is contained in:
Elias Schneider
2026-04-18 22:31:24 +02:00
committed by GitHub
parent c5a4ffa523
commit 5c4d7ff877
19 changed files with 355 additions and 91 deletions

View File

@@ -99,6 +99,7 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
c.Request.Context(),
input,
c.GetString("userID"),
c.GetString("authenticationMethod"),
c.ClientIP(),
c.Request.UserAgent(),
)
@@ -808,7 +809,14 @@ func (oc *OidcController) verifyDeviceCodeHandler(c *gin.Context) {
ipAddress := c.ClientIP()
userAgent := c.Request.UserAgent()
err := oc.oidcService.VerifyDeviceCode(c.Request.Context(), userCode, c.GetString("userID"), ipAddress, userAgent)
err := oc.oidcService.VerifyDeviceCode(
c.Request.Context(),
userCode,
c.GetString("userID"),
c.GetString("authenticationMethod"),
ipAddress,
userAgent)
if err != nil {
_ = c.Error(err)
return
@@ -864,7 +872,13 @@ func (oc *OidcController) getClientPreviewHandler(c *gin.Context) {
return
}
preview, err := oc.oidcService.GetClientPreview(c.Request.Context(), clientID, userID, strings.Split(scopes, " "))
preview, err := oc.oidcService.GetClientPreview(
c.Request.Context(),
clientID,
userID,
strings.Split(scopes, " "),
c.GetString("authenticationMethod"))
if err != nil {
_ = c.Error(err)
return

View File

@@ -74,10 +74,11 @@ func (m *AuthMiddleware) WithApiKeyAuthDisabled() *AuthMiddleware {
func (m *AuthMiddleware) Add() gin.HandlerFunc {
return func(c *gin.Context) {
userID, isAdmin, err := m.jwtMiddleware.Verify(c, m.options.AdminRequired)
userID, isAdmin, authenticationMethod, err := m.jwtMiddleware.Verify(c, m.options.AdminRequired)
if err == nil {
c.Set("userID", userID)
c.Set("userIsAdmin", isAdmin)
c.Set("authenticationMethod", authenticationMethod)
if c.IsAborted() {
return
}

View File

@@ -44,7 +44,7 @@ func TestWithApiKeyAuthDisabled(t *testing.T) {
authMiddleware := NewAuthMiddleware(apiKeyService, userService, jwtService)
user := createUserForAuthMiddlewareTest(t, db)
jwtToken, err := jwtService.GenerateAccessToken(user)
jwtToken, err := jwtService.GenerateAccessToken(user, "")
require.NoError(t, err)
_, apiKeyToken, err := apiKeyService.CreateApiKey(t.Context(), user.ID, dto.ApiKeyCreateDto{

View File

@@ -20,7 +20,7 @@ func NewJwtAuthMiddleware(jwtService *service.JwtService, userService *service.U
func (m *JwtAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
return func(c *gin.Context) {
userID, isAdmin, err := m.Verify(c, adminRequired)
userID, isAdmin, authenticationMethod, err := m.Verify(c, adminRequired)
if err != nil {
c.Abort()
_ = c.Error(err)
@@ -29,11 +29,12 @@ func (m *JwtAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
c.Set("userID", userID)
c.Set("userIsAdmin", isAdmin)
c.Set("authenticationMethod", authenticationMethod)
c.Next()
}
}
func (m *JwtAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (subject string, isAdmin bool, err error) {
func (m *JwtAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (subject string, isAdmin bool, authenticationMethod string, err error) {
// Extract the token from the cookie
accessToken, err := c.Cookie(cookie.AccessTokenCookieName)
if err != nil {
@@ -41,33 +42,37 @@ func (m *JwtAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (subject
var ok bool
_, accessToken, ok = strings.Cut(c.GetHeader("Authorization"), " ")
if !ok || accessToken == "" {
return "", false, &common.NotSignedInError{}
return "", false, "", &common.NotSignedInError{}
}
}
token, err := m.jwtService.VerifyAccessToken(accessToken)
if err != nil {
return "", false, &common.NotSignedInError{}
return "", false, "", &common.NotSignedInError{}
}
authenticationMethod, err = service.GetAuthenticationMethod(token)
if err != nil {
return "", false, "", &common.NotSignedInError{}
}
subject, ok := token.Subject()
if !ok {
_ = c.Error(&common.TokenInvalidError{})
return
return "", false, "", &common.TokenInvalidError{}
}
user, err := m.userService.GetUser(c, subject)
if err != nil {
return "", false, &common.NotSignedInError{}
return "", false, "", &common.NotSignedInError{}
}
if user.Disabled {
return "", false, &common.UserDisabledError{}
return "", false, "", &common.UserDisabledError{}
}
if adminRequired && !user.IsAdmin {
return "", false, &common.MissingPermissionError{}
return "", false, "", &common.MissingPermissionError{}
}
return subject, isAdmin, nil
return subject, user.IsAdmin, authenticationMethod, nil
}

View File

@@ -33,6 +33,7 @@ type OidcAuthorizationCode struct {
Code string
Scope string
AuthenticationMethod string
Nonce string
CodeChallenge *string
CodeChallengeMethodSha256 *bool
@@ -77,9 +78,10 @@ func (c OidcClient) HasDarkLogo() bool {
type OidcRefreshToken struct {
Base
Token string
ExpiresAt datatype.DateTime
Scope string
Token string
ExpiresAt datatype.DateTime
Scope string
AuthenticationMethod string
UserID string
User User
@@ -141,12 +143,13 @@ func (cu UrlList) Value() (driver.Value, error) {
type OidcDeviceCode struct {
Base
DeviceCode string
UserCode string
Scope string
Nonce string
ExpiresAt datatype.DateTime
IsAuthorized bool
DeviceCode string
UserCode string
Scope string
AuthenticationMethod string
Nonce string
ExpiresAt datatype.DateTime
IsAuthorized bool
UserID *string
User User

View File

@@ -245,20 +245,22 @@ func (s *TestService) SeedDatabase(baseURL string) error {
authCodes := []model.OidcAuthorizationCode{
{
Code: "auth-code",
Scope: "openid profile",
Nonce: "nonce",
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
UserID: users[0].ID,
ClientID: oidcClients[0].ID,
Code: "auth-code",
Scope: "openid profile",
Nonce: "nonce",
AuthenticationMethod: AuthenticationMethodPhishingResistant,
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
UserID: users[0].ID,
ClientID: oidcClients[0].ID,
},
{
Code: "federated",
Scope: "openid profile",
Nonce: "nonce",
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
UserID: users[1].ID,
ClientID: oidcClients[3].ID,
Code: "federated",
Scope: "openid profile",
Nonce: "nonce",
AuthenticationMethod: AuthenticationMethodPhishingResistant,
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
UserID: users[1].ID,
ClientID: oidcClients[3].ID,
},
}
for _, authCode := range authCodes {
@@ -268,11 +270,12 @@ func (s *TestService) SeedDatabase(baseURL string) error {
}
refreshToken := model.OidcRefreshToken{
Token: utils.CreateSha256Hash("ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo"),
ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)),
Scope: "openid profile email",
UserID: users[0].ID,
ClientID: oidcClients[0].ID,
Token: utils.CreateSha256Hash("ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo"),
AuthenticationMethod: AuthenticationMethodPhishingResistant,
ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)),
Scope: "openid profile email",
UserID: users[0].ID,
ClientID: oidcClients[0].ID,
}
if err := tx.Create(&refreshToken).Error; err != nil {
return err

View File

@@ -32,6 +32,15 @@ const (
// RefreshTokenClaim is the claim used for the refresh token's value
RefreshTokenClaim = "rt"
// AuthenticationMethodsClaim is the claim used to identify how the user authenticated
AuthenticationMethodsClaim = "amr"
// AuthenticationMethodPhishingResistant identifies phishing-resistant authentication, such as passkeys
AuthenticationMethodPhishingResistant = "phr"
// AuthenticationMethodOneTimePassword identifies one-time password/code authentication
AuthenticationMethodOneTimePassword = "otp"
// OAuthAccessTokenJWTType identifies a JWT as an OAuth access token
OAuthAccessTokenJWTType = "oauth-access-token" //nolint:gosec
@@ -187,7 +196,8 @@ func (s *JwtService) SetKey(privateKey jwk.Key) error {
return nil
}
func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
func (s *JwtService) GenerateAccessToken(user model.User, authenticationMethod string) (string, error) {
now := time.Now()
token, err := jwt.NewBuilder().
Subject(user.ID).
@@ -215,6 +225,11 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
return "", fmt.Errorf("failed to set 'isAdmin' claim in token: %w", err)
}
err = SetAuthenticationMethods(token, authenticationMethod)
if err != nil {
return "", fmt.Errorf("failed to set '%s' claim in token: %w", AuthenticationMethodsClaim, err)
}
alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
if err != nil {
@@ -243,7 +258,7 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
}
// BuildIDToken creates an ID token with all claims
func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, nonce string) (jwt.Token, error) {
func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, nonce string, authenticationMethod string) (jwt.Token, error) {
now := time.Now()
token, err := jwt.NewBuilder().
Expiration(now.Add(1 * time.Hour)).
@@ -265,6 +280,11 @@ func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, no
return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err)
}
err = SetAuthenticationMethods(token, authenticationMethod)
if err != nil {
return nil, fmt.Errorf("failed to set '%s' claim in token: %w", AuthenticationMethodsClaim, err)
}
for k, v := range userClaims {
err = token.Set(k, v)
if err != nil {
@@ -283,8 +303,8 @@ func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, no
}
// GenerateIDToken creates and signs an ID token
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string) (string, error) {
token, err := s.BuildIDToken(userClaims, clientID, nonce)
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string, authenticationMethod string) (string, error) {
token, err := s.BuildIDToken(userClaims, clientID, nonce, authenticationMethod)
if err != nil {
return "", err
}
@@ -332,7 +352,7 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
}
// BuildOAuthAccessToken creates an OAuth access token with all claims
func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jwt.Token, error) {
func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string, authenticationMethod string) (jwt.Token, error) {
now := time.Now()
token, err := jwt.NewBuilder().
Subject(user.ID).
@@ -355,12 +375,17 @@ func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jw
return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err)
}
err = SetAuthenticationMethods(token, authenticationMethod)
if err != nil {
return nil, fmt.Errorf("failed to set '%s' claim in token: %w", AuthenticationMethodsClaim, err)
}
return token, nil
}
// GenerateOAuthAccessToken creates and signs an OAuth access token
func (s *JwtService) GenerateOAuthAccessToken(user model.User, clientID string) (string, error) {
token, err := s.BuildOAuthAccessToken(user, clientID)
func (s *JwtService) GenerateOAuthAccessToken(user model.User, clientID string, authenticationMethod string) (string, error) {
token, err := s.BuildOAuthAccessToken(user, clientID, authenticationMethod)
if err != nil {
return "", err
}
@@ -534,6 +559,27 @@ func GetIsAdmin(token jwt.Token) (bool, error) {
return isAdmin, nil
}
// GetAuthenticationMethod returns the first authentication method in the "amr" claim in the token
func GetAuthenticationMethod(token jwt.Token) (string, error) {
if !token.Has(AuthenticationMethodsClaim) {
return "", nil
}
var rawAuthenticationMethods []any
err := token.Get(AuthenticationMethodsClaim, &rawAuthenticationMethods)
if err != nil {
return "", fmt.Errorf("failed to get '%s' claim from token: %w", AuthenticationMethodsClaim, err)
}
if len(rawAuthenticationMethods) == 0 {
return "", nil
}
authenticationMethod, ok := rawAuthenticationMethods[0].(string)
if !ok {
return "", fmt.Errorf("invalid '%s' claim in token: expected array of strings", AuthenticationMethodsClaim)
}
return authenticationMethod, nil
}
// SetTokenType sets the "type" claim in the token
func SetTokenType(token jwt.Token, tokenType string) error {
if tokenType == "" {
@@ -551,6 +597,14 @@ func SetIsAdmin(token jwt.Token, isAdmin bool) error {
return token.Set(IsAdminClaim, isAdmin)
}
// SetAuthenticationMethods sets the authentication method references claim in the token
func SetAuthenticationMethods(token jwt.Token, authenticationMethod string) error {
if authenticationMethod == "" {
return nil
}
return token.Set(AuthenticationMethodsClaim, []string{authenticationMethod})
}
// SetAudienceString sets the "aud" claim with a value that is a string, and not an array
// This is permitted by RFC 7519, and it's done here for backwards-compatibility
func SetAudienceString(token jwt.Token, audience string) error {

View File

@@ -174,6 +174,7 @@ func TestJwtService_Init(t *testing.T) {
_ = assert.True(t, ok) &&
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
})
}
func TestJwtService_GetPublicJWK(t *testing.T) {
@@ -308,7 +309,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
IsAdmin: false,
}
tokenString, err := service.GenerateAccessToken(user)
tokenString, err := service.GenerateAccessToken(user, "")
require.NoError(t, err, "Failed to generate access token")
assert.NotEmpty(t, tokenString, "Token should not be empty")
@@ -321,6 +322,9 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
isAdmin, err := GetIsAdmin(claims)
_ = assert.NoError(t, err, "Failed to get isAdmin claim") &&
assert.False(t, isAdmin, "isAdmin should be false")
authenticationMethod, err := GetAuthenticationMethod(claims)
_ = assert.NoError(t, err, "Failed to get amr claim") &&
assert.Empty(t, authenticationMethod, "amr should be empty when not specified")
audience, ok := claims.Audience()
_ = assert.True(t, ok, "Audience not found in token") &&
assert.Equal(t, []string{service.envConfig.AppURL}, audience, "Audience should contain the app URL")
@@ -344,7 +348,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
IsAdmin: true,
}
tokenString, err := service.GenerateAccessToken(adminUser)
tokenString, err := service.GenerateAccessToken(adminUser, "")
require.NoError(t, err, "Failed to generate access token")
claims, err := service.VerifyAccessToken(tokenString)
@@ -358,6 +362,24 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
assert.Equal(t, adminUser.ID, subject, "Token subject should match user ID")
})
t.Run("sets authentication method references claim when provided", func(t *testing.T) {
service, _, _ := setupJwtService(t, mockConfig)
user := model.User{
Base: model.Base{ID: "user-with-auth-method"},
}
tokenString, err := service.GenerateAccessToken(user, AuthenticationMethodPhishingResistant)
require.NoError(t, err, "Failed to generate access token")
claims, err := service.VerifyAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated token")
authenticationMethod, err := GetAuthenticationMethod(claims)
_ = assert.NoError(t, err, "Failed to get amr claim") &&
assert.Equal(t, AuthenticationMethodPhishingResistant, authenticationMethod, "amr should match")
})
t.Run("uses session duration from config", func(t *testing.T) {
customMockConfig := NewTestAppConfigService(&model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "30"}, // 30 minutes
@@ -368,7 +390,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
Base: model.Base{ID: "user456"},
}
tokenString, err := service.GenerateAccessToken(user)
tokenString, err := service.GenerateAccessToken(user, "")
require.NoError(t, err, "Failed to generate access token")
claims, err := service.VerifyAccessToken(tokenString)
@@ -396,7 +418,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
IsAdmin: true,
}
tokenString, err := service.GenerateAccessToken(user)
tokenString, err := service.GenerateAccessToken(user, "")
require.NoError(t, err, "Failed to generate access token with Ed25519 key")
assert.NotEmpty(t, tokenString, "Token should not be empty")
@@ -433,7 +455,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
IsAdmin: true,
}
tokenString, err := service.GenerateAccessToken(user)
tokenString, err := service.GenerateAccessToken(user, "")
require.NoError(t, err, "Failed to generate access token with ECDSA key")
assert.NotEmpty(t, tokenString, "Token should not be empty")
@@ -470,7 +492,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
IsAdmin: true,
}
tokenString, err := service.GenerateAccessToken(user)
tokenString, err := service.GenerateAccessToken(user, "")
require.NoError(t, err, "Failed to generate access token with RSA key")
assert.NotEmpty(t, tokenString, "Token should not be empty")
@@ -508,7 +530,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
}
const clientID = "test-client-123"
tokenString, err := service.GenerateIDToken(userClaims, clientID, "")
tokenString, err := service.GenerateIDToken(userClaims, clientID, "", "")
require.NoError(t, err, "Failed to generate ID token")
assert.NotEmpty(t, tokenString, "Token should not be empty")
@@ -593,7 +615,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
const clientID = "test-client-456"
nonce := "random-nonce-value"
tokenString, err := service.GenerateIDToken(userClaims, clientID, nonce)
tokenString, err := service.GenerateIDToken(userClaims, clientID, nonce, "")
require.NoError(t, err, "Failed to generate ID token with nonce")
publicKey, err := service.GetPublicJWK()
@@ -614,7 +636,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
userClaims := map[string]any{
"sub": "user789",
}
tokenString, err := service.GenerateIDToken(userClaims, "client-789", "")
tokenString, err := service.GenerateIDToken(userClaims, "client-789", "", "")
require.NoError(t, err, "Failed to generate ID token")
service.envConfig.AppURL = "https://wrong-issuer.com"
@@ -640,7 +662,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
}
const clientID = "eddsa-client-123"
tokenString, err := service.GenerateIDToken(userClaims, clientID, "")
tokenString, err := service.GenerateIDToken(userClaims, clientID, "", "")
require.NoError(t, err, "Failed to generate ID token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty")
@@ -677,7 +699,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
}
const clientID = "ecdsa-client-123"
tokenString, err := service.GenerateIDToken(userClaims, clientID, "")
tokenString, err := service.GenerateIDToken(userClaims, clientID, "", "")
require.NoError(t, err, "Failed to generate ID token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty")
@@ -715,7 +737,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
}
const clientID = "rsa-client-123"
tokenString, err := service.GenerateIDToken(userClaims, clientID, "")
tokenString, err := service.GenerateIDToken(userClaims, clientID, "", "")
require.NoError(t, err, "Failed to generate ID token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty")
@@ -745,7 +767,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
}
const clientID = "test-client-123"
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
tokenString, err := service.GenerateOAuthAccessToken(user, clientID, "")
require.NoError(t, err, "Failed to generate OAuth access token")
assert.NotEmpty(t, tokenString, "Token should not be empty")
@@ -772,6 +794,25 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
assert.InDelta(t, 0, timeDiff, 1.0, "Token should expire in approximately 1 hour")
})
t.Run("sets authentication method references claim when provided", func(t *testing.T) {
service, _, _ := setupJwtService(t, mockConfig)
user := model.User{
Base: model.Base{ID: "oauth-amr-user"},
}
const clientID = "test-client-amr"
tokenString, err := service.GenerateOAuthAccessToken(user, clientID, AuthenticationMethodPhishingResistant)
require.NoError(t, err, "Failed to generate OAuth access token")
claims, err := service.VerifyOAuthAccessToken(tokenString)
require.NoError(t, err, "Failed to verify generated OAuth access token")
authenticationMethod, err := GetAuthenticationMethod(claims)
_ = assert.NoError(t, err, "Failed to get amr claim") &&
assert.Equal(t, AuthenticationMethodPhishingResistant, authenticationMethod, "amr should match")
})
t.Run("fails verification for expired token", func(t *testing.T) {
service, _, _ := setupJwtService(t, mockConfig)
@@ -805,7 +846,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
user := model.User{Base: model.Base{ID: "user789"}}
const clientID = "test-client-789"
tokenString, err := service1.GenerateOAuthAccessToken(user, clientID)
tokenString, err := service1.GenerateOAuthAccessToken(user, clientID, "")
require.NoError(t, err, "Failed to generate OAuth access token")
_, err = service2.VerifyOAuthAccessToken(tokenString)
@@ -828,7 +869,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
}
const clientID = "eddsa-oauth-client"
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
tokenString, err := service.GenerateOAuthAccessToken(user, clientID, "")
require.NoError(t, err, "Failed to generate OAuth access token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty")
@@ -865,7 +906,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
}
const clientID = "ecdsa-oauth-client"
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
tokenString, err := service.GenerateOAuthAccessToken(user, clientID, "")
require.NoError(t, err, "Failed to generate OAuth access token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty")
@@ -902,7 +943,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
}
const clientID = "rsa-oauth-client"
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
tokenString, err := service.GenerateOAuthAccessToken(user, clientID, "")
require.NoError(t, err, "Failed to generate OAuth access token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty")

View File

@@ -123,7 +123,7 @@ func (s *OidcService) getJWKCache(ctx context.Context) (*jwk.Cache, error) {
)
}
func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID string, authenticationMethod string, ipAddress, userAgent string) (string, string, error) {
tx := s.db.Begin()
defer tx.Rollback()
@@ -179,7 +179,7 @@ func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClie
}
// Create the authorization code
code, err := s.createAuthorizationCode(ctx, input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod, tx)
code, err := s.createAuthorizationCode(ctx, input.ClientID, userID, input.Scope, authenticationMethod, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod, tx)
if err != nil {
return "", "", err
}
@@ -312,17 +312,17 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
}
// Explicitly use the input clientID for the audience claim to ensure consistency
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, deviceAuth.Nonce)
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, deviceAuth.Nonce, deviceAuth.AuthenticationMethod)
if err != nil {
return CreatedTokens{}, err
}
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, *deviceAuth.UserID, deviceAuth.Scope, tx)
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, *deviceAuth.UserID, deviceAuth.Scope, deviceAuth.AuthenticationMethod, tx)
if err != nil {
return CreatedTokens{}, err
}
accessToken, err := s.jwtService.GenerateOAuthAccessToken(deviceAuth.User, input.ClientID)
accessToken, err := s.jwtService.GenerateOAuthAccessToken(deviceAuth.User, input.ClientID, deviceAuth.AuthenticationMethod)
if err != nil {
return CreatedTokens{}, err
}
@@ -363,7 +363,7 @@ func (s *OidcService) createTokenFromClientCredentials(ctx context.Context, inpu
audClaim = input.Resource
}
accessToken, err := s.jwtService.GenerateOAuthAccessToken(dummyUser, audClaim)
accessToken, err := s.jwtService.GenerateOAuthAccessToken(dummyUser, audClaim, "")
if err != nil {
return CreatedTokens{}, err
}
@@ -411,18 +411,20 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
return CreatedTokens{}, err
}
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, authorizationCodeMetaData.Nonce)
authenticationMethod := authorizationCodeMetaData.AuthenticationMethod
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, authorizationCodeMetaData.Nonce, authenticationMethod)
if err != nil {
return CreatedTokens{}, err
}
// Generate a refresh token
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx)
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, authenticationMethod, tx)
if err != nil {
return CreatedTokens{}, err
}
accessToken, err := s.jwtService.GenerateOAuthAccessToken(authorizationCodeMetaData.User, input.ClientID)
accessToken, err := s.jwtService.GenerateOAuthAccessToken(authorizationCodeMetaData.User, input.ClientID, authenticationMethod)
if err != nil {
return CreatedTokens{}, err
}
@@ -500,7 +502,8 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
}
// Generate a new access token
accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID)
authenticationMethods := storedRefreshToken.AuthenticationMethod
accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID, authenticationMethods)
if err != nil {
return CreatedTokens{}, err
}
@@ -513,13 +516,13 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
// Generate a new ID token
// There's no nonce here because we don't have one with the refresh token, but that's not required
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, "")
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, "", authenticationMethods)
if err != nil {
return CreatedTokens{}, err
}
// Generate a new refresh token and invalidate the old one
newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, authenticationMethods, tx)
if err != nil {
return CreatedTokens{}, err
}
@@ -1140,7 +1143,7 @@ func (s *OidcService) ValidateEndSession(ctx context.Context, input dto.OidcLogo
return callbackURL, nil
}
func (s *OidcService) createAuthorizationCode(ctx context.Context, clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string, tx *gorm.DB) (string, error) {
func (s *OidcService) createAuthorizationCode(ctx context.Context, clientID string, userID string, scope string, authenticationMethod string, nonce string, codeChallenge string, codeChallengeMethod string, tx *gorm.DB) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(32)
if err != nil {
return "", err
@@ -1154,6 +1157,7 @@ func (s *OidcService) createAuthorizationCode(ctx context.Context, clientID stri
ClientID: clientID,
UserID: userID,
Scope: scope,
AuthenticationMethod: authenticationMethod,
Nonce: nonce,
CodeChallenge: &codeChallenge,
CodeChallengeMethodSha256: &codeChallengeMethodSha256,
@@ -1297,7 +1301,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
}, nil
}
func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, userID string, ipAddress string, userAgent string) error {
func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, userID string, authenticationMethod string, ipAddress string, userAgent string) error {
tx := s.db.Begin()
defer func() {
tx.Rollback()
@@ -1346,6 +1350,7 @@ func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, use
}
deviceAuth.UserID = &userID
deviceAuth.AuthenticationMethod = authenticationMethod
deviceAuth.IsAuthorized = true
err = tx.
@@ -1549,7 +1554,7 @@ func (s *OidcService) ListAccessibleOidcClients(ctx context.Context, userID stri
return dtos, response, err
}
func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, userID string, scope string, tx *gorm.DB) (string, error) {
func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, userID string, scope string, authenticationMethod string, tx *gorm.DB) (string, error) {
refreshToken, err := utils.GenerateRandomAlphanumericString(40)
if err != nil {
return "", err
@@ -1560,11 +1565,12 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
refreshTokenHash := utils.CreateSha256Hash(refreshToken)
m := model.OidcRefreshToken{
ExpiresAt: datatype.DateTime(time.Now().Add(RefreshTokenDuration)),
Token: refreshTokenHash,
ClientID: clientID,
UserID: userID,
Scope: scope,
ExpiresAt: datatype.DateTime(time.Now().Add(RefreshTokenDuration)),
Token: refreshTokenHash,
ClientID: clientID,
UserID: userID,
Scope: scope,
AuthenticationMethod: authenticationMethod,
}
err = tx.
@@ -1780,7 +1786,7 @@ func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.C
return nil
}
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes []string) (*dto.OidcClientPreviewDto, error) {
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes []string, authenticationMethod string) (*dto.OidcClientPreviewDto, error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
@@ -1816,12 +1822,12 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
return nil, err
}
idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "")
idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "", authenticationMethod)
if err != nil {
return nil, err
}
accessToken, err := s.jwtService.BuildOAuthAccessToken(user, clientID)
accessToken, err := s.jwtService.BuildOAuthAccessToken(user, clientID, authenticationMethod)
if err != nil {
return nil, err
}

View File

@@ -562,6 +562,46 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
})
}
func TestOidcServiceAuthenticationMethodsPersistence(t *testing.T) {
mockConfig := NewTestAppConfigService(&model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"},
})
jwtService, db, _ := setupJwtService(t, mockConfig)
service := &OidcService{
db: db,
jwtService: jwtService,
}
authenticationMethod := AuthenticationMethodPhishingResistant
t.Run("stores authentication method on authorization codes", func(t *testing.T) {
code, err := service.createAuthorizationCode(
t.Context(),
"amr-client",
"amr-user",
"openid profile",
authenticationMethod,
"",
"",
"",
db,
)
require.NoError(t, err)
var authorizationCode model.OidcAuthorizationCode
require.NoError(t, db.First(&authorizationCode, "code = ?", code).Error)
assert.Equal(t, authenticationMethod, authorizationCode.AuthenticationMethod)
})
t.Run("stores authentication methods on refresh tokens", func(t *testing.T) {
_, err := service.createRefreshToken(t.Context(), "amr-client", "amr-user", "openid profile", authenticationMethod, db)
require.NoError(t, err)
var refreshToken model.OidcRefreshToken
require.NoError(t, db.First(&refreshToken, "client_id = ? AND user_id = ?", "amr-client", "amr-user").Error)
assert.Equal(t, authenticationMethod, refreshToken.AuthenticationMethod)
})
}
func TestValidateCodeVerifier_Plain(t *testing.T) {
require.False(t, validateCodeVerifier("", "", false))
require.False(t, validateCodeVerifier("", "", true))

View File

@@ -197,7 +197,7 @@ func (s *OneTimeAccessService) ExchangeOneTimeAccessToken(ctx context.Context, t
return model.User{}, "", &common.DeviceCodeInvalid{}
}
accessToken, err := s.jwtService.GenerateAccessToken(oneTimeAccessToken.User)
accessToken, err := s.jwtService.GenerateAccessToken(oneTimeAccessToken.User, AuthenticationMethodOneTimePassword)
if err != nil {
return model.User{}, "", err
}

View File

@@ -87,7 +87,7 @@ func (s *UserSignUpService) SignUp(ctx context.Context, signupData dto.SignUpDto
return model.User{}, "", err
}
accessToken, err := s.jwtService.GenerateAccessToken(user)
accessToken, err := s.jwtService.GenerateAccessToken(user, "")
if err != nil {
return model.User{}, "", err
}
@@ -148,7 +148,7 @@ func (s *UserSignUpService) SignUpInitialAdmin(ctx context.Context, signUpData d
return model.User{}, "", err
}
token, err := s.jwtService.GenerateAccessToken(user)
token, err := s.jwtService.GenerateAccessToken(user, AuthenticationMethodOneTimePassword)
if err != nil {
return model.User{}, "", err
}

View File

@@ -266,7 +266,7 @@ func (s *WebAuthnService) VerifyLogin(ctx context.Context, sessionID string, cre
return model.User{}, "", &common.UserDisabledError{}
}
token, err := s.jwtService.GenerateAccessToken(*user)
token, err := s.jwtService.GenerateAccessToken(*user, AuthenticationMethodPhishingResistant)
if err != nil {
return model.User{}, "", err
}
@@ -389,6 +389,14 @@ func (s *WebAuthnService) CreateReauthenticationTokenWithAccessToken(ctx context
return "", errors.New("access token does not contain user ID")
}
authenticationMethod, err := GetAuthenticationMethod(token)
if err != nil {
return "", err
}
if authenticationMethod != AuthenticationMethodPhishingResistant {
return "", &common.ReauthenticationRequiredError{}
}
// Check if token is issued less than a minute ago
tokenExpiration, ok := token.IssuedAt()
if !ok || time.Since(tokenExpiration) > time.Minute {

View File

@@ -0,0 +1,68 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
)
func TestCreateReauthenticationTokenWithAccessToken(t *testing.T) {
mockConfig := NewTestAppConfigService(&model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"},
})
setupService := func(t *testing.T) (*WebAuthnService, model.User) {
t.Helper()
jwtService, db, _ := setupJwtService(t, mockConfig)
user := model.User{
Base: model.Base{ID: "reauth-user"},
Username: "reauth-user",
}
require.NoError(t, db.Create(&user).Error)
return &WebAuthnService{
db: db,
jwtService: jwtService,
}, user
}
t.Run("accepts a fresh access token from WebAuthn login", func(t *testing.T) {
service, user := setupService(t)
accessToken, err := service.jwtService.GenerateAccessToken(user, AuthenticationMethodPhishingResistant)
require.NoError(t, err)
reauthenticationToken, err := service.CreateReauthenticationTokenWithAccessToken(t.Context(), accessToken)
require.NoError(t, err)
assert.NotEmpty(t, reauthenticationToken)
})
t.Run("rejects a fresh access token from one-time access login", func(t *testing.T) {
service, user := setupService(t)
accessToken, err := service.jwtService.GenerateAccessToken(user, AuthenticationMethodOneTimePassword)
require.NoError(t, err)
reauthenticationToken, err := service.CreateReauthenticationTokenWithAccessToken(t.Context(), accessToken)
assert.Empty(t, reauthenticationToken)
require.Error(t, err)
assert.ErrorAs(t, err, new(*common.ReauthenticationRequiredError))
})
t.Run("rejects a fresh access token without an authentication method", func(t *testing.T) {
service, user := setupService(t)
accessToken, err := service.jwtService.GenerateAccessToken(user, "")
require.NoError(t, err)
reauthenticationToken, err := service.CreateReauthenticationTokenWithAccessToken(t.Context(), accessToken)
assert.Empty(t, reauthenticationToken)
require.Error(t, err)
assert.ErrorAs(t, err, new(*common.ReauthenticationRequiredError))
})
}

View File

@@ -0,0 +1,3 @@
ALTER TABLE oidc_authorization_codes DROP COLUMN authentication_method;
ALTER TABLE oidc_refresh_tokens DROP COLUMN authentication_method;
ALTER TABLE oidc_device_codes DROP COLUMN authentication_method;

View File

@@ -0,0 +1,6 @@
ALTER TABLE oidc_authorization_codes
ADD COLUMN authentication_method TEXT NOT NULL DEFAULT '';
ALTER TABLE oidc_refresh_tokens
ADD COLUMN authentication_method TEXT NOT NULL DEFAULT '';
ALTER TABLE oidc_device_codes
ADD COLUMN authentication_method TEXT NOT NULL DEFAULT '';

View File

@@ -0,0 +1,3 @@
ALTER TABLE oidc_authorization_codes DROP COLUMN authentication_method;
ALTER TABLE oidc_refresh_tokens DROP COLUMN authentication_method;
ALTER TABLE oidc_device_codes DROP COLUMN authentication_method;

View File

@@ -0,0 +1,6 @@
ALTER TABLE oidc_authorization_codes
ADD COLUMN authentication_method TEXT NOT NULL DEFAULT '';
ALTER TABLE oidc_refresh_tokens
ADD COLUMN authentication_method TEXT NOT NULL DEFAULT '';
ALTER TABLE oidc_device_codes
ADD COLUMN authentication_method TEXT NOT NULL DEFAULT '';

View File

@@ -1,6 +1,6 @@
{
"provider": "sqlite",
"version": 20260109090200,
"version": 20260418120000,
"tableOrder": ["users", "user_groups", "oidc_clients", "signup_tokens"],
"tables": {
"api_keys": [
@@ -42,6 +42,7 @@
"oidc_authorization_codes": [
{
"client_id": "3654a746-35d4-4321-ac61-0bdcff2b4055",
"authentication_method": "phr",
"code": "auth-code",
"code_challenge": null,
"code_challenge_method_sha256": null,
@@ -54,6 +55,7 @@
},
{
"client_id": "c48232ff-ff65-45ed-ae96-7afa8a9b443b",
"authentication_method": "phr",
"code": "federated",
"code_challenge": null,
"code_challenge_method_sha256": null,
@@ -168,6 +170,7 @@
"oidc_refresh_tokens": [
{
"client_id": "3654a746-35d4-4321-ac61-0bdcff2b4055",
"authentication_method": "phr",
"created_at": "2025-11-25T12:39:02Z",
"expires_at": "2025-11-26T12:39:02Z",
"id": "4928604e-e689-410c-9b25-5b9b6db9e46e",