mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-05-12 07:59:52 +00:00
feat: add auth method claim (amr) to tokens (#1433)
This commit is contained in:
@@ -99,6 +99,7 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
||||
c.Request.Context(),
|
||||
input,
|
||||
c.GetString("userID"),
|
||||
c.GetString("authenticationMethod"),
|
||||
c.ClientIP(),
|
||||
c.Request.UserAgent(),
|
||||
)
|
||||
@@ -808,7 +809,14 @@ func (oc *OidcController) verifyDeviceCodeHandler(c *gin.Context) {
|
||||
ipAddress := c.ClientIP()
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
err := oc.oidcService.VerifyDeviceCode(c.Request.Context(), userCode, c.GetString("userID"), ipAddress, userAgent)
|
||||
err := oc.oidcService.VerifyDeviceCode(
|
||||
c.Request.Context(),
|
||||
userCode,
|
||||
c.GetString("userID"),
|
||||
c.GetString("authenticationMethod"),
|
||||
ipAddress,
|
||||
userAgent)
|
||||
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
@@ -864,7 +872,13 @@ func (oc *OidcController) getClientPreviewHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
preview, err := oc.oidcService.GetClientPreview(c.Request.Context(), clientID, userID, strings.Split(scopes, " "))
|
||||
preview, err := oc.oidcService.GetClientPreview(
|
||||
c.Request.Context(),
|
||||
clientID,
|
||||
userID,
|
||||
strings.Split(scopes, " "),
|
||||
c.GetString("authenticationMethod"))
|
||||
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
|
||||
@@ -74,10 +74,11 @@ func (m *AuthMiddleware) WithApiKeyAuthDisabled() *AuthMiddleware {
|
||||
|
||||
func (m *AuthMiddleware) Add() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
userID, isAdmin, err := m.jwtMiddleware.Verify(c, m.options.AdminRequired)
|
||||
userID, isAdmin, authenticationMethod, err := m.jwtMiddleware.Verify(c, m.options.AdminRequired)
|
||||
if err == nil {
|
||||
c.Set("userID", userID)
|
||||
c.Set("userIsAdmin", isAdmin)
|
||||
c.Set("authenticationMethod", authenticationMethod)
|
||||
if c.IsAborted() {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ func TestWithApiKeyAuthDisabled(t *testing.T) {
|
||||
authMiddleware := NewAuthMiddleware(apiKeyService, userService, jwtService)
|
||||
|
||||
user := createUserForAuthMiddlewareTest(t, db)
|
||||
jwtToken, err := jwtService.GenerateAccessToken(user)
|
||||
jwtToken, err := jwtService.GenerateAccessToken(user, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, apiKeyToken, err := apiKeyService.CreateApiKey(t.Context(), user.ID, dto.ApiKeyCreateDto{
|
||||
|
||||
@@ -20,7 +20,7 @@ func NewJwtAuthMiddleware(jwtService *service.JwtService, userService *service.U
|
||||
|
||||
func (m *JwtAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
userID, isAdmin, err := m.Verify(c, adminRequired)
|
||||
userID, isAdmin, authenticationMethod, err := m.Verify(c, adminRequired)
|
||||
if err != nil {
|
||||
c.Abort()
|
||||
_ = c.Error(err)
|
||||
@@ -29,11 +29,12 @@ func (m *JwtAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
|
||||
|
||||
c.Set("userID", userID)
|
||||
c.Set("userIsAdmin", isAdmin)
|
||||
c.Set("authenticationMethod", authenticationMethod)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *JwtAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (subject string, isAdmin bool, err error) {
|
||||
func (m *JwtAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (subject string, isAdmin bool, authenticationMethod string, err error) {
|
||||
// Extract the token from the cookie
|
||||
accessToken, err := c.Cookie(cookie.AccessTokenCookieName)
|
||||
if err != nil {
|
||||
@@ -41,33 +42,37 @@ func (m *JwtAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (subject
|
||||
var ok bool
|
||||
_, accessToken, ok = strings.Cut(c.GetHeader("Authorization"), " ")
|
||||
if !ok || accessToken == "" {
|
||||
return "", false, &common.NotSignedInError{}
|
||||
return "", false, "", &common.NotSignedInError{}
|
||||
}
|
||||
}
|
||||
|
||||
token, err := m.jwtService.VerifyAccessToken(accessToken)
|
||||
if err != nil {
|
||||
return "", false, &common.NotSignedInError{}
|
||||
return "", false, "", &common.NotSignedInError{}
|
||||
}
|
||||
authenticationMethod, err = service.GetAuthenticationMethod(token)
|
||||
if err != nil {
|
||||
return "", false, "", &common.NotSignedInError{}
|
||||
}
|
||||
|
||||
subject, ok := token.Subject()
|
||||
if !ok {
|
||||
_ = c.Error(&common.TokenInvalidError{})
|
||||
return
|
||||
return "", false, "", &common.TokenInvalidError{}
|
||||
}
|
||||
|
||||
user, err := m.userService.GetUser(c, subject)
|
||||
if err != nil {
|
||||
return "", false, &common.NotSignedInError{}
|
||||
return "", false, "", &common.NotSignedInError{}
|
||||
}
|
||||
|
||||
if user.Disabled {
|
||||
return "", false, &common.UserDisabledError{}
|
||||
return "", false, "", &common.UserDisabledError{}
|
||||
}
|
||||
|
||||
if adminRequired && !user.IsAdmin {
|
||||
return "", false, &common.MissingPermissionError{}
|
||||
return "", false, "", &common.MissingPermissionError{}
|
||||
}
|
||||
|
||||
return subject, isAdmin, nil
|
||||
return subject, user.IsAdmin, authenticationMethod, nil
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ type OidcAuthorizationCode struct {
|
||||
|
||||
Code string
|
||||
Scope string
|
||||
AuthenticationMethod string
|
||||
Nonce string
|
||||
CodeChallenge *string
|
||||
CodeChallengeMethodSha256 *bool
|
||||
@@ -77,9 +78,10 @@ func (c OidcClient) HasDarkLogo() bool {
|
||||
type OidcRefreshToken struct {
|
||||
Base
|
||||
|
||||
Token string
|
||||
ExpiresAt datatype.DateTime
|
||||
Scope string
|
||||
Token string
|
||||
ExpiresAt datatype.DateTime
|
||||
Scope string
|
||||
AuthenticationMethod string
|
||||
|
||||
UserID string
|
||||
User User
|
||||
@@ -141,12 +143,13 @@ func (cu UrlList) Value() (driver.Value, error) {
|
||||
|
||||
type OidcDeviceCode struct {
|
||||
Base
|
||||
DeviceCode string
|
||||
UserCode string
|
||||
Scope string
|
||||
Nonce string
|
||||
ExpiresAt datatype.DateTime
|
||||
IsAuthorized bool
|
||||
DeviceCode string
|
||||
UserCode string
|
||||
Scope string
|
||||
AuthenticationMethod string
|
||||
Nonce string
|
||||
ExpiresAt datatype.DateTime
|
||||
IsAuthorized bool
|
||||
|
||||
UserID *string
|
||||
User User
|
||||
|
||||
@@ -245,20 +245,22 @@ func (s *TestService) SeedDatabase(baseURL string) error {
|
||||
|
||||
authCodes := []model.OidcAuthorizationCode{
|
||||
{
|
||||
Code: "auth-code",
|
||||
Scope: "openid profile",
|
||||
Nonce: "nonce",
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
|
||||
UserID: users[0].ID,
|
||||
ClientID: oidcClients[0].ID,
|
||||
Code: "auth-code",
|
||||
Scope: "openid profile",
|
||||
Nonce: "nonce",
|
||||
AuthenticationMethod: AuthenticationMethodPhishingResistant,
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
|
||||
UserID: users[0].ID,
|
||||
ClientID: oidcClients[0].ID,
|
||||
},
|
||||
{
|
||||
Code: "federated",
|
||||
Scope: "openid profile",
|
||||
Nonce: "nonce",
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
|
||||
UserID: users[1].ID,
|
||||
ClientID: oidcClients[3].ID,
|
||||
Code: "federated",
|
||||
Scope: "openid profile",
|
||||
Nonce: "nonce",
|
||||
AuthenticationMethod: AuthenticationMethodPhishingResistant,
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
|
||||
UserID: users[1].ID,
|
||||
ClientID: oidcClients[3].ID,
|
||||
},
|
||||
}
|
||||
for _, authCode := range authCodes {
|
||||
@@ -268,11 +270,12 @@ func (s *TestService) SeedDatabase(baseURL string) error {
|
||||
}
|
||||
|
||||
refreshToken := model.OidcRefreshToken{
|
||||
Token: utils.CreateSha256Hash("ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo"),
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)),
|
||||
Scope: "openid profile email",
|
||||
UserID: users[0].ID,
|
||||
ClientID: oidcClients[0].ID,
|
||||
Token: utils.CreateSha256Hash("ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo"),
|
||||
AuthenticationMethod: AuthenticationMethodPhishingResistant,
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)),
|
||||
Scope: "openid profile email",
|
||||
UserID: users[0].ID,
|
||||
ClientID: oidcClients[0].ID,
|
||||
}
|
||||
if err := tx.Create(&refreshToken).Error; err != nil {
|
||||
return err
|
||||
|
||||
@@ -32,6 +32,15 @@ const (
|
||||
// RefreshTokenClaim is the claim used for the refresh token's value
|
||||
RefreshTokenClaim = "rt"
|
||||
|
||||
// AuthenticationMethodsClaim is the claim used to identify how the user authenticated
|
||||
AuthenticationMethodsClaim = "amr"
|
||||
|
||||
// AuthenticationMethodPhishingResistant identifies phishing-resistant authentication, such as passkeys
|
||||
AuthenticationMethodPhishingResistant = "phr"
|
||||
|
||||
// AuthenticationMethodOneTimePassword identifies one-time password/code authentication
|
||||
AuthenticationMethodOneTimePassword = "otp"
|
||||
|
||||
// OAuthAccessTokenJWTType identifies a JWT as an OAuth access token
|
||||
OAuthAccessTokenJWTType = "oauth-access-token" //nolint:gosec
|
||||
|
||||
@@ -187,7 +196,8 @@ func (s *JwtService) SetKey(privateKey jwk.Key) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
||||
func (s *JwtService) GenerateAccessToken(user model.User, authenticationMethod string) (string, error) {
|
||||
|
||||
now := time.Now()
|
||||
token, err := jwt.NewBuilder().
|
||||
Subject(user.ID).
|
||||
@@ -215,6 +225,11 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
||||
return "", fmt.Errorf("failed to set 'isAdmin' claim in token: %w", err)
|
||||
}
|
||||
|
||||
err = SetAuthenticationMethods(token, authenticationMethod)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set '%s' claim in token: %w", AuthenticationMethodsClaim, err)
|
||||
}
|
||||
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
|
||||
if err != nil {
|
||||
@@ -243,7 +258,7 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
|
||||
}
|
||||
|
||||
// BuildIDToken creates an ID token with all claims
|
||||
func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, nonce string) (jwt.Token, error) {
|
||||
func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, nonce string, authenticationMethod string) (jwt.Token, error) {
|
||||
now := time.Now()
|
||||
token, err := jwt.NewBuilder().
|
||||
Expiration(now.Add(1 * time.Hour)).
|
||||
@@ -265,6 +280,11 @@ func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, no
|
||||
return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err)
|
||||
}
|
||||
|
||||
err = SetAuthenticationMethods(token, authenticationMethod)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to set '%s' claim in token: %w", AuthenticationMethodsClaim, err)
|
||||
}
|
||||
|
||||
for k, v := range userClaims {
|
||||
err = token.Set(k, v)
|
||||
if err != nil {
|
||||
@@ -283,8 +303,8 @@ func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, no
|
||||
}
|
||||
|
||||
// GenerateIDToken creates and signs an ID token
|
||||
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string) (string, error) {
|
||||
token, err := s.BuildIDToken(userClaims, clientID, nonce)
|
||||
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string, authenticationMethod string) (string, error) {
|
||||
token, err := s.BuildIDToken(userClaims, clientID, nonce, authenticationMethod)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -332,7 +352,7 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
|
||||
}
|
||||
|
||||
// BuildOAuthAccessToken creates an OAuth access token with all claims
|
||||
func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jwt.Token, error) {
|
||||
func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string, authenticationMethod string) (jwt.Token, error) {
|
||||
now := time.Now()
|
||||
token, err := jwt.NewBuilder().
|
||||
Subject(user.ID).
|
||||
@@ -355,12 +375,17 @@ func (s *JwtService) BuildOAuthAccessToken(user model.User, clientID string) (jw
|
||||
return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err)
|
||||
}
|
||||
|
||||
err = SetAuthenticationMethods(token, authenticationMethod)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to set '%s' claim in token: %w", AuthenticationMethodsClaim, err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GenerateOAuthAccessToken creates and signs an OAuth access token
|
||||
func (s *JwtService) GenerateOAuthAccessToken(user model.User, clientID string) (string, error) {
|
||||
token, err := s.BuildOAuthAccessToken(user, clientID)
|
||||
func (s *JwtService) GenerateOAuthAccessToken(user model.User, clientID string, authenticationMethod string) (string, error) {
|
||||
token, err := s.BuildOAuthAccessToken(user, clientID, authenticationMethod)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -534,6 +559,27 @@ func GetIsAdmin(token jwt.Token) (bool, error) {
|
||||
return isAdmin, nil
|
||||
}
|
||||
|
||||
// GetAuthenticationMethod returns the first authentication method in the "amr" claim in the token
|
||||
func GetAuthenticationMethod(token jwt.Token) (string, error) {
|
||||
if !token.Has(AuthenticationMethodsClaim) {
|
||||
return "", nil
|
||||
}
|
||||
var rawAuthenticationMethods []any
|
||||
err := token.Get(AuthenticationMethodsClaim, &rawAuthenticationMethods)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get '%s' claim from token: %w", AuthenticationMethodsClaim, err)
|
||||
}
|
||||
|
||||
if len(rawAuthenticationMethods) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
authenticationMethod, ok := rawAuthenticationMethods[0].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid '%s' claim in token: expected array of strings", AuthenticationMethodsClaim)
|
||||
}
|
||||
return authenticationMethod, nil
|
||||
}
|
||||
|
||||
// SetTokenType sets the "type" claim in the token
|
||||
func SetTokenType(token jwt.Token, tokenType string) error {
|
||||
if tokenType == "" {
|
||||
@@ -551,6 +597,14 @@ func SetIsAdmin(token jwt.Token, isAdmin bool) error {
|
||||
return token.Set(IsAdminClaim, isAdmin)
|
||||
}
|
||||
|
||||
// SetAuthenticationMethods sets the authentication method references claim in the token
|
||||
func SetAuthenticationMethods(token jwt.Token, authenticationMethod string) error {
|
||||
if authenticationMethod == "" {
|
||||
return nil
|
||||
}
|
||||
return token.Set(AuthenticationMethodsClaim, []string{authenticationMethod})
|
||||
}
|
||||
|
||||
// SetAudienceString sets the "aud" claim with a value that is a string, and not an array
|
||||
// This is permitted by RFC 7519, and it's done here for backwards-compatibility
|
||||
func SetAudienceString(token jwt.Token, audience string) error {
|
||||
|
||||
@@ -174,6 +174,7 @@ func TestJwtService_Init(t *testing.T) {
|
||||
_ = assert.True(t, ok) &&
|
||||
assert.Equal(t, origKeyID, loadedKeyID, "Loaded key should have the same ID as the original")
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestJwtService_GetPublicJWK(t *testing.T) {
|
||||
@@ -308,7 +309,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
IsAdmin: false,
|
||||
}
|
||||
|
||||
tokenString, err := service.GenerateAccessToken(user)
|
||||
tokenString, err := service.GenerateAccessToken(user, "")
|
||||
require.NoError(t, err, "Failed to generate access token")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
@@ -321,6 +322,9 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
isAdmin, err := GetIsAdmin(claims)
|
||||
_ = assert.NoError(t, err, "Failed to get isAdmin claim") &&
|
||||
assert.False(t, isAdmin, "isAdmin should be false")
|
||||
authenticationMethod, err := GetAuthenticationMethod(claims)
|
||||
_ = assert.NoError(t, err, "Failed to get amr claim") &&
|
||||
assert.Empty(t, authenticationMethod, "amr should be empty when not specified")
|
||||
audience, ok := claims.Audience()
|
||||
_ = assert.True(t, ok, "Audience not found in token") &&
|
||||
assert.Equal(t, []string{service.envConfig.AppURL}, audience, "Audience should contain the app URL")
|
||||
@@ -344,7 +348,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
IsAdmin: true,
|
||||
}
|
||||
|
||||
tokenString, err := service.GenerateAccessToken(adminUser)
|
||||
tokenString, err := service.GenerateAccessToken(adminUser, "")
|
||||
require.NoError(t, err, "Failed to generate access token")
|
||||
|
||||
claims, err := service.VerifyAccessToken(tokenString)
|
||||
@@ -358,6 +362,24 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
assert.Equal(t, adminUser.ID, subject, "Token subject should match user ID")
|
||||
})
|
||||
|
||||
t.Run("sets authentication method references claim when provided", func(t *testing.T) {
|
||||
service, _, _ := setupJwtService(t, mockConfig)
|
||||
|
||||
user := model.User{
|
||||
Base: model.Base{ID: "user-with-auth-method"},
|
||||
}
|
||||
|
||||
tokenString, err := service.GenerateAccessToken(user, AuthenticationMethodPhishingResistant)
|
||||
require.NoError(t, err, "Failed to generate access token")
|
||||
|
||||
claims, err := service.VerifyAccessToken(tokenString)
|
||||
require.NoError(t, err, "Failed to verify generated token")
|
||||
|
||||
authenticationMethod, err := GetAuthenticationMethod(claims)
|
||||
_ = assert.NoError(t, err, "Failed to get amr claim") &&
|
||||
assert.Equal(t, AuthenticationMethodPhishingResistant, authenticationMethod, "amr should match")
|
||||
})
|
||||
|
||||
t.Run("uses session duration from config", func(t *testing.T) {
|
||||
customMockConfig := NewTestAppConfigService(&model.AppConfig{
|
||||
SessionDuration: model.AppConfigVariable{Value: "30"}, // 30 minutes
|
||||
@@ -368,7 +390,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
Base: model.Base{ID: "user456"},
|
||||
}
|
||||
|
||||
tokenString, err := service.GenerateAccessToken(user)
|
||||
tokenString, err := service.GenerateAccessToken(user, "")
|
||||
require.NoError(t, err, "Failed to generate access token")
|
||||
|
||||
claims, err := service.VerifyAccessToken(tokenString)
|
||||
@@ -396,7 +418,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
IsAdmin: true,
|
||||
}
|
||||
|
||||
tokenString, err := service.GenerateAccessToken(user)
|
||||
tokenString, err := service.GenerateAccessToken(user, "")
|
||||
require.NoError(t, err, "Failed to generate access token with Ed25519 key")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
@@ -433,7 +455,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
IsAdmin: true,
|
||||
}
|
||||
|
||||
tokenString, err := service.GenerateAccessToken(user)
|
||||
tokenString, err := service.GenerateAccessToken(user, "")
|
||||
require.NoError(t, err, "Failed to generate access token with ECDSA key")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
@@ -470,7 +492,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
||||
IsAdmin: true,
|
||||
}
|
||||
|
||||
tokenString, err := service.GenerateAccessToken(user)
|
||||
tokenString, err := service.GenerateAccessToken(user, "")
|
||||
require.NoError(t, err, "Failed to generate access token with RSA key")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
@@ -508,7 +530,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
}
|
||||
const clientID = "test-client-123"
|
||||
|
||||
tokenString, err := service.GenerateIDToken(userClaims, clientID, "")
|
||||
tokenString, err := service.GenerateIDToken(userClaims, clientID, "", "")
|
||||
require.NoError(t, err, "Failed to generate ID token")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
@@ -593,7 +615,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
const clientID = "test-client-456"
|
||||
nonce := "random-nonce-value"
|
||||
|
||||
tokenString, err := service.GenerateIDToken(userClaims, clientID, nonce)
|
||||
tokenString, err := service.GenerateIDToken(userClaims, clientID, nonce, "")
|
||||
require.NoError(t, err, "Failed to generate ID token with nonce")
|
||||
|
||||
publicKey, err := service.GetPublicJWK()
|
||||
@@ -614,7 +636,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
userClaims := map[string]any{
|
||||
"sub": "user789",
|
||||
}
|
||||
tokenString, err := service.GenerateIDToken(userClaims, "client-789", "")
|
||||
tokenString, err := service.GenerateIDToken(userClaims, "client-789", "", "")
|
||||
require.NoError(t, err, "Failed to generate ID token")
|
||||
|
||||
service.envConfig.AppURL = "https://wrong-issuer.com"
|
||||
@@ -640,7 +662,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
}
|
||||
const clientID = "eddsa-client-123"
|
||||
|
||||
tokenString, err := service.GenerateIDToken(userClaims, clientID, "")
|
||||
tokenString, err := service.GenerateIDToken(userClaims, clientID, "", "")
|
||||
require.NoError(t, err, "Failed to generate ID token with key")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
@@ -677,7 +699,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
}
|
||||
const clientID = "ecdsa-client-123"
|
||||
|
||||
tokenString, err := service.GenerateIDToken(userClaims, clientID, "")
|
||||
tokenString, err := service.GenerateIDToken(userClaims, clientID, "", "")
|
||||
require.NoError(t, err, "Failed to generate ID token with key")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
@@ -715,7 +737,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
|
||||
}
|
||||
const clientID = "rsa-client-123"
|
||||
|
||||
tokenString, err := service.GenerateIDToken(userClaims, clientID, "")
|
||||
tokenString, err := service.GenerateIDToken(userClaims, clientID, "", "")
|
||||
require.NoError(t, err, "Failed to generate ID token with key")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
@@ -745,7 +767,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
}
|
||||
const clientID = "test-client-123"
|
||||
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID, "")
|
||||
require.NoError(t, err, "Failed to generate OAuth access token")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
@@ -772,6 +794,25 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
assert.InDelta(t, 0, timeDiff, 1.0, "Token should expire in approximately 1 hour")
|
||||
})
|
||||
|
||||
t.Run("sets authentication method references claim when provided", func(t *testing.T) {
|
||||
service, _, _ := setupJwtService(t, mockConfig)
|
||||
|
||||
user := model.User{
|
||||
Base: model.Base{ID: "oauth-amr-user"},
|
||||
}
|
||||
const clientID = "test-client-amr"
|
||||
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID, AuthenticationMethodPhishingResistant)
|
||||
require.NoError(t, err, "Failed to generate OAuth access token")
|
||||
|
||||
claims, err := service.VerifyOAuthAccessToken(tokenString)
|
||||
require.NoError(t, err, "Failed to verify generated OAuth access token")
|
||||
|
||||
authenticationMethod, err := GetAuthenticationMethod(claims)
|
||||
_ = assert.NoError(t, err, "Failed to get amr claim") &&
|
||||
assert.Equal(t, AuthenticationMethodPhishingResistant, authenticationMethod, "amr should match")
|
||||
})
|
||||
|
||||
t.Run("fails verification for expired token", func(t *testing.T) {
|
||||
service, _, _ := setupJwtService(t, mockConfig)
|
||||
|
||||
@@ -805,7 +846,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
user := model.User{Base: model.Base{ID: "user789"}}
|
||||
const clientID = "test-client-789"
|
||||
|
||||
tokenString, err := service1.GenerateOAuthAccessToken(user, clientID)
|
||||
tokenString, err := service1.GenerateOAuthAccessToken(user, clientID, "")
|
||||
require.NoError(t, err, "Failed to generate OAuth access token")
|
||||
|
||||
_, err = service2.VerifyOAuthAccessToken(tokenString)
|
||||
@@ -828,7 +869,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
}
|
||||
const clientID = "eddsa-oauth-client"
|
||||
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID, "")
|
||||
require.NoError(t, err, "Failed to generate OAuth access token with key")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
@@ -865,7 +906,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
}
|
||||
const clientID = "ecdsa-oauth-client"
|
||||
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID, "")
|
||||
require.NoError(t, err, "Failed to generate OAuth access token with key")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
@@ -902,7 +943,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
||||
}
|
||||
const clientID = "rsa-oauth-client"
|
||||
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID)
|
||||
tokenString, err := service.GenerateOAuthAccessToken(user, clientID, "")
|
||||
require.NoError(t, err, "Failed to generate OAuth access token with key")
|
||||
assert.NotEmpty(t, tokenString, "Token should not be empty")
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ func (s *OidcService) getJWKCache(ctx context.Context) (*jwk.Cache, error) {
|
||||
)
|
||||
}
|
||||
|
||||
func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
|
||||
func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID string, authenticationMethod string, ipAddress, userAgent string) (string, string, error) {
|
||||
tx := s.db.Begin()
|
||||
defer tx.Rollback()
|
||||
|
||||
@@ -179,7 +179,7 @@ func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClie
|
||||
}
|
||||
|
||||
// Create the authorization code
|
||||
code, err := s.createAuthorizationCode(ctx, input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod, tx)
|
||||
code, err := s.createAuthorizationCode(ctx, input.ClientID, userID, input.Scope, authenticationMethod, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod, tx)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
@@ -312,17 +312,17 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
|
||||
}
|
||||
|
||||
// Explicitly use the input clientID for the audience claim to ensure consistency
|
||||
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, deviceAuth.Nonce)
|
||||
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, deviceAuth.Nonce, deviceAuth.AuthenticationMethod)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, *deviceAuth.UserID, deviceAuth.Scope, tx)
|
||||
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, *deviceAuth.UserID, deviceAuth.Scope, deviceAuth.AuthenticationMethod, tx)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(deviceAuth.User, input.ClientID)
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(deviceAuth.User, input.ClientID, deviceAuth.AuthenticationMethod)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -363,7 +363,7 @@ func (s *OidcService) createTokenFromClientCredentials(ctx context.Context, inpu
|
||||
audClaim = input.Resource
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(dummyUser, audClaim)
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(dummyUser, audClaim, "")
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -411,18 +411,20 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, authorizationCodeMetaData.Nonce)
|
||||
authenticationMethod := authorizationCodeMetaData.AuthenticationMethod
|
||||
|
||||
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, authorizationCodeMetaData.Nonce, authenticationMethod)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
// Generate a refresh token
|
||||
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, tx)
|
||||
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, authenticationMethod, tx)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(authorizationCodeMetaData.User, input.ClientID)
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(authorizationCodeMetaData.User, input.ClientID, authenticationMethod)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -500,7 +502,8 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
||||
}
|
||||
|
||||
// Generate a new access token
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID)
|
||||
authenticationMethods := storedRefreshToken.AuthenticationMethod
|
||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID, authenticationMethods)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -513,13 +516,13 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
||||
|
||||
// Generate a new ID token
|
||||
// There's no nonce here because we don't have one with the refresh token, but that's not required
|
||||
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, "")
|
||||
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, "", authenticationMethods)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
|
||||
// Generate a new refresh token and invalidate the old one
|
||||
newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
|
||||
newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, authenticationMethods, tx)
|
||||
if err != nil {
|
||||
return CreatedTokens{}, err
|
||||
}
|
||||
@@ -1140,7 +1143,7 @@ func (s *OidcService) ValidateEndSession(ctx context.Context, input dto.OidcLogo
|
||||
return callbackURL, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) createAuthorizationCode(ctx context.Context, clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string, tx *gorm.DB) (string, error) {
|
||||
func (s *OidcService) createAuthorizationCode(ctx context.Context, clientID string, userID string, scope string, authenticationMethod string, nonce string, codeChallenge string, codeChallengeMethod string, tx *gorm.DB) (string, error) {
|
||||
randomString, err := utils.GenerateRandomAlphanumericString(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -1154,6 +1157,7 @@ func (s *OidcService) createAuthorizationCode(ctx context.Context, clientID stri
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
Scope: scope,
|
||||
AuthenticationMethod: authenticationMethod,
|
||||
Nonce: nonce,
|
||||
CodeChallenge: &codeChallenge,
|
||||
CodeChallengeMethodSha256: &codeChallengeMethodSha256,
|
||||
@@ -1297,7 +1301,7 @@ func (s *OidcService) CreateDeviceAuthorization(ctx context.Context, input dto.O
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, userID string, ipAddress string, userAgent string) error {
|
||||
func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, userID string, authenticationMethod string, ipAddress string, userAgent string) error {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
@@ -1346,6 +1350,7 @@ func (s *OidcService) VerifyDeviceCode(ctx context.Context, userCode string, use
|
||||
}
|
||||
|
||||
deviceAuth.UserID = &userID
|
||||
deviceAuth.AuthenticationMethod = authenticationMethod
|
||||
deviceAuth.IsAuthorized = true
|
||||
|
||||
err = tx.
|
||||
@@ -1549,7 +1554,7 @@ func (s *OidcService) ListAccessibleOidcClients(ctx context.Context, userID stri
|
||||
return dtos, response, err
|
||||
}
|
||||
|
||||
func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, userID string, scope string, tx *gorm.DB) (string, error) {
|
||||
func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, userID string, scope string, authenticationMethod string, tx *gorm.DB) (string, error) {
|
||||
refreshToken, err := utils.GenerateRandomAlphanumericString(40)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -1560,11 +1565,12 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
|
||||
refreshTokenHash := utils.CreateSha256Hash(refreshToken)
|
||||
|
||||
m := model.OidcRefreshToken{
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(RefreshTokenDuration)),
|
||||
Token: refreshTokenHash,
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
Scope: scope,
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(RefreshTokenDuration)),
|
||||
Token: refreshTokenHash,
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
Scope: scope,
|
||||
AuthenticationMethod: authenticationMethod,
|
||||
}
|
||||
|
||||
err = tx.
|
||||
@@ -1780,7 +1786,7 @@ func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.C
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes []string) (*dto.OidcClientPreviewDto, error) {
|
||||
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes []string, authenticationMethod string) (*dto.OidcClientPreviewDto, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
@@ -1816,12 +1822,12 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "")
|
||||
idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "", authenticationMethod)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.BuildOAuthAccessToken(user, clientID)
|
||||
accessToken, err := s.jwtService.BuildOAuthAccessToken(user, clientID, authenticationMethod)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -562,6 +562,46 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestOidcServiceAuthenticationMethodsPersistence(t *testing.T) {
|
||||
mockConfig := NewTestAppConfigService(&model.AppConfig{
|
||||
SessionDuration: model.AppConfigVariable{Value: "60"},
|
||||
})
|
||||
jwtService, db, _ := setupJwtService(t, mockConfig)
|
||||
service := &OidcService{
|
||||
db: db,
|
||||
jwtService: jwtService,
|
||||
}
|
||||
authenticationMethod := AuthenticationMethodPhishingResistant
|
||||
|
||||
t.Run("stores authentication method on authorization codes", func(t *testing.T) {
|
||||
code, err := service.createAuthorizationCode(
|
||||
t.Context(),
|
||||
"amr-client",
|
||||
"amr-user",
|
||||
"openid profile",
|
||||
authenticationMethod,
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
db,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
var authorizationCode model.OidcAuthorizationCode
|
||||
require.NoError(t, db.First(&authorizationCode, "code = ?", code).Error)
|
||||
assert.Equal(t, authenticationMethod, authorizationCode.AuthenticationMethod)
|
||||
})
|
||||
|
||||
t.Run("stores authentication methods on refresh tokens", func(t *testing.T) {
|
||||
_, err := service.createRefreshToken(t.Context(), "amr-client", "amr-user", "openid profile", authenticationMethod, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
var refreshToken model.OidcRefreshToken
|
||||
require.NoError(t, db.First(&refreshToken, "client_id = ? AND user_id = ?", "amr-client", "amr-user").Error)
|
||||
assert.Equal(t, authenticationMethod, refreshToken.AuthenticationMethod)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateCodeVerifier_Plain(t *testing.T) {
|
||||
require.False(t, validateCodeVerifier("", "", false))
|
||||
require.False(t, validateCodeVerifier("", "", true))
|
||||
|
||||
@@ -197,7 +197,7 @@ func (s *OneTimeAccessService) ExchangeOneTimeAccessToken(ctx context.Context, t
|
||||
return model.User{}, "", &common.DeviceCodeInvalid{}
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.GenerateAccessToken(oneTimeAccessToken.User)
|
||||
accessToken, err := s.jwtService.GenerateAccessToken(oneTimeAccessToken.User, AuthenticationMethodOneTimePassword)
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ func (s *UserSignUpService) SignUp(ctx context.Context, signupData dto.SignUpDto
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.GenerateAccessToken(user)
|
||||
accessToken, err := s.jwtService.GenerateAccessToken(user, "")
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
@@ -148,7 +148,7 @@ func (s *UserSignUpService) SignUpInitialAdmin(ctx context.Context, signUpData d
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
token, err := s.jwtService.GenerateAccessToken(user)
|
||||
token, err := s.jwtService.GenerateAccessToken(user, AuthenticationMethodOneTimePassword)
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
@@ -266,7 +266,7 @@ func (s *WebAuthnService) VerifyLogin(ctx context.Context, sessionID string, cre
|
||||
return model.User{}, "", &common.UserDisabledError{}
|
||||
}
|
||||
|
||||
token, err := s.jwtService.GenerateAccessToken(*user)
|
||||
token, err := s.jwtService.GenerateAccessToken(*user, AuthenticationMethodPhishingResistant)
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
@@ -389,6 +389,14 @@ func (s *WebAuthnService) CreateReauthenticationTokenWithAccessToken(ctx context
|
||||
return "", errors.New("access token does not contain user ID")
|
||||
}
|
||||
|
||||
authenticationMethod, err := GetAuthenticationMethod(token)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if authenticationMethod != AuthenticationMethodPhishingResistant {
|
||||
return "", &common.ReauthenticationRequiredError{}
|
||||
}
|
||||
|
||||
// Check if token is issued less than a minute ago
|
||||
tokenExpiration, ok := token.IssuedAt()
|
||||
if !ok || time.Since(tokenExpiration) > time.Minute {
|
||||
|
||||
68
backend/internal/service/webauthn_service_test.go
Normal file
68
backend/internal/service/webauthn_service_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
)
|
||||
|
||||
func TestCreateReauthenticationTokenWithAccessToken(t *testing.T) {
|
||||
mockConfig := NewTestAppConfigService(&model.AppConfig{
|
||||
SessionDuration: model.AppConfigVariable{Value: "60"},
|
||||
})
|
||||
|
||||
setupService := func(t *testing.T) (*WebAuthnService, model.User) {
|
||||
t.Helper()
|
||||
|
||||
jwtService, db, _ := setupJwtService(t, mockConfig)
|
||||
user := model.User{
|
||||
Base: model.Base{ID: "reauth-user"},
|
||||
Username: "reauth-user",
|
||||
}
|
||||
require.NoError(t, db.Create(&user).Error)
|
||||
|
||||
return &WebAuthnService{
|
||||
db: db,
|
||||
jwtService: jwtService,
|
||||
}, user
|
||||
}
|
||||
|
||||
t.Run("accepts a fresh access token from WebAuthn login", func(t *testing.T) {
|
||||
service, user := setupService(t)
|
||||
accessToken, err := service.jwtService.GenerateAccessToken(user, AuthenticationMethodPhishingResistant)
|
||||
require.NoError(t, err)
|
||||
|
||||
reauthenticationToken, err := service.CreateReauthenticationTokenWithAccessToken(t.Context(), accessToken)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, reauthenticationToken)
|
||||
})
|
||||
|
||||
t.Run("rejects a fresh access token from one-time access login", func(t *testing.T) {
|
||||
service, user := setupService(t)
|
||||
accessToken, err := service.jwtService.GenerateAccessToken(user, AuthenticationMethodOneTimePassword)
|
||||
require.NoError(t, err)
|
||||
|
||||
reauthenticationToken, err := service.CreateReauthenticationTokenWithAccessToken(t.Context(), accessToken)
|
||||
|
||||
assert.Empty(t, reauthenticationToken)
|
||||
require.Error(t, err)
|
||||
assert.ErrorAs(t, err, new(*common.ReauthenticationRequiredError))
|
||||
})
|
||||
|
||||
t.Run("rejects a fresh access token without an authentication method", func(t *testing.T) {
|
||||
service, user := setupService(t)
|
||||
accessToken, err := service.jwtService.GenerateAccessToken(user, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
reauthenticationToken, err := service.CreateReauthenticationTokenWithAccessToken(t.Context(), accessToken)
|
||||
|
||||
assert.Empty(t, reauthenticationToken)
|
||||
require.Error(t, err)
|
||||
assert.ErrorAs(t, err, new(*common.ReauthenticationRequiredError))
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE oidc_authorization_codes DROP COLUMN authentication_method;
|
||||
ALTER TABLE oidc_refresh_tokens DROP COLUMN authentication_method;
|
||||
ALTER TABLE oidc_device_codes DROP COLUMN authentication_method;
|
||||
@@ -0,0 +1,6 @@
|
||||
ALTER TABLE oidc_authorization_codes
|
||||
ADD COLUMN authentication_method TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE oidc_refresh_tokens
|
||||
ADD COLUMN authentication_method TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE oidc_device_codes
|
||||
ADD COLUMN authentication_method TEXT NOT NULL DEFAULT '';
|
||||
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE oidc_authorization_codes DROP COLUMN authentication_method;
|
||||
ALTER TABLE oidc_refresh_tokens DROP COLUMN authentication_method;
|
||||
ALTER TABLE oidc_device_codes DROP COLUMN authentication_method;
|
||||
@@ -0,0 +1,6 @@
|
||||
ALTER TABLE oidc_authorization_codes
|
||||
ADD COLUMN authentication_method TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE oidc_refresh_tokens
|
||||
ADD COLUMN authentication_method TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE oidc_device_codes
|
||||
ADD COLUMN authentication_method TEXT NOT NULL DEFAULT '';
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"provider": "sqlite",
|
||||
"version": 20260109090200,
|
||||
"version": 20260418120000,
|
||||
"tableOrder": ["users", "user_groups", "oidc_clients", "signup_tokens"],
|
||||
"tables": {
|
||||
"api_keys": [
|
||||
@@ -42,6 +42,7 @@
|
||||
"oidc_authorization_codes": [
|
||||
{
|
||||
"client_id": "3654a746-35d4-4321-ac61-0bdcff2b4055",
|
||||
"authentication_method": "phr",
|
||||
"code": "auth-code",
|
||||
"code_challenge": null,
|
||||
"code_challenge_method_sha256": null,
|
||||
@@ -54,6 +55,7 @@
|
||||
},
|
||||
{
|
||||
"client_id": "c48232ff-ff65-45ed-ae96-7afa8a9b443b",
|
||||
"authentication_method": "phr",
|
||||
"code": "federated",
|
||||
"code_challenge": null,
|
||||
"code_challenge_method_sha256": null,
|
||||
@@ -168,6 +170,7 @@
|
||||
"oidc_refresh_tokens": [
|
||||
{
|
||||
"client_id": "3654a746-35d4-4321-ac61-0bdcff2b4055",
|
||||
"authentication_method": "phr",
|
||||
"created_at": "2025-11-25T12:39:02Z",
|
||||
"expires_at": "2025-11-26T12:39:02Z",
|
||||
"id": "4928604e-e689-410c-9b25-5b9b6db9e46e",
|
||||
|
||||
Reference in New Issue
Block a user