feat: delete OAuth refresh token on RP initiated logout (#1480)

This commit is contained in:
Elias Schneider
2026-05-19 17:05:44 +02:00
committed by GitHub
parent b9fdd530c0
commit 9dd3d319cf
16 changed files with 230 additions and 39 deletions

View File

@@ -311,6 +311,12 @@ func (oc *OidcController) EndSessionHandler(c *gin.Context) {
// The validation was successful, so we can log out and redirect the user to the callback URL without confirmation
cookie.AddAccessTokenCookie(c, 0, "")
// Callback URL can be empty if none is configured
if callbackURL == "" {
c.Redirect(http.StatusFound, common.EnvConfig.AppURL+"/logout")
return
}
logoutCallbackURL, _ := url.Parse(callbackURL)
if input.State != "" {
q := logoutCallbackURL.Query()

View File

@@ -79,6 +79,7 @@ type OidcRefreshToken struct {
Base
Token string
IdTokenJti *string
ExpiresAt datatype.DateTime
Scope string
AuthenticationMethod string

View File

@@ -271,6 +271,7 @@ func (s *TestService) SeedDatabase(baseURL string) error {
refreshToken := model.OidcRefreshToken{
Token: utils.CreateSha256Hash("ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo"),
IdTokenJti: new("dd75f6f6-ce0a-44b7-a645-7de390ccd2fa"),
AuthenticationMethod: AuthenticationMethodPhishingResistant,
ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)),
Scope: "openid profile email",

View File

@@ -258,64 +258,65 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
}
// BuildIDToken creates an ID token with all claims
func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, nonce string, authenticationMethod string) (jwt.Token, error) {
func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, nonce string, authenticationMethod string) (jwt.Token, string, error) {
now := time.Now()
jti := uuid.New().String()
token, err := jwt.NewBuilder().
Expiration(now.Add(1 * time.Hour)).
IssuedAt(now).
Issuer(s.envConfig.AppURL).
JwtID(uuid.New().String()).
JwtID(jti).
Build()
if err != nil {
return nil, fmt.Errorf("failed to build token: %w", err)
return nil, "", fmt.Errorf("failed to build token: %w", err)
}
err = SetAudienceString(token, clientID)
if err != nil {
return nil, fmt.Errorf("failed to set 'aud' claim in token: %w", err)
return nil, "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
err = SetTokenType(token, IDTokenJWTType)
if err != nil {
return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err)
return nil, "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
}
err = SetAuthenticationMethods(token, authenticationMethod)
if err != nil {
return nil, fmt.Errorf("failed to set '%s' claim in token: %w", AuthenticationMethodsClaim, err)
return nil, "", fmt.Errorf("failed to set '%s' claim in token: %w", AuthenticationMethodsClaim, err)
}
for k, v := range userClaims {
err = token.Set(k, v)
if err != nil {
return nil, fmt.Errorf("failed to set claim '%s': %w", k, err)
return nil, "", fmt.Errorf("failed to set claim '%s': %w", k, err)
}
}
if nonce != "" {
err = token.Set("nonce", nonce)
if err != nil {
return nil, fmt.Errorf("failed to set claim 'nonce': %w", err)
return nil, "", fmt.Errorf("failed to set claim 'nonce': %w", err)
}
}
return token, nil
return token, jti, nil
}
// GenerateIDToken creates and signs an ID token
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string, authenticationMethod string) (string, error) {
token, err := s.BuildIDToken(userClaims, clientID, nonce, authenticationMethod)
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string, authenticationMethod string) (signedToken, jti string, err error) {
token, jti, err := s.BuildIDToken(userClaims, clientID, nonce, authenticationMethod)
if err != nil {
return "", err
return "", "", err
}
alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
return "", "", fmt.Errorf("failed to sign token: %w", err)
}
return string(signed), nil
return string(signed), jti, nil
}
func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool) (jwt.Token, error) {

View File

@@ -530,9 +530,10 @@ func TestGenerateVerifyIdToken(t *testing.T) {
}
const clientID = "test-client-123"
tokenString, err := service.GenerateIDToken(userClaims, clientID, "", "")
tokenString, jti, err := service.GenerateIDToken(userClaims, clientID, "", "")
require.NoError(t, err, "Failed to generate ID token")
assert.NotEmpty(t, tokenString, "Token should not be empty")
assert.Regexp(t, uuidRegexPattern, jti, "Returned JWT ID is not a UUID")
claims, err := service.VerifyIdToken(tokenString, false)
require.NoError(t, err, "Failed to verify generated ID token")
@@ -549,6 +550,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
jwtID, ok := claims.JwtID()
_ = assert.True(t, ok, "JWT ID not found in token") &&
assert.Regexp(t, uuidRegexPattern, jwtID, "JWT ID is not a UUID")
assert.Equal(t, jti, jwtID, "Returned JWT ID should match token claim")
expectedExp := time.Now().Add(1 * time.Hour)
expiration, ok := claims.Expiration()
@@ -615,7 +617,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
const clientID = "test-client-456"
nonce := "random-nonce-value"
tokenString, err := service.GenerateIDToken(userClaims, clientID, nonce, "")
tokenString, _, err := service.GenerateIDToken(userClaims, clientID, nonce, "")
require.NoError(t, err, "Failed to generate ID token with nonce")
publicKey, err := service.GetPublicJWK()
@@ -636,7 +638,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
userClaims := map[string]any{
"sub": "user789",
}
tokenString, err := service.GenerateIDToken(userClaims, "client-789", "", "")
tokenString, _, err := service.GenerateIDToken(userClaims, "client-789", "", "")
require.NoError(t, err, "Failed to generate ID token")
service.envConfig.AppURL = "https://wrong-issuer.com"
@@ -662,7 +664,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
}
const clientID = "eddsa-client-123"
tokenString, err := service.GenerateIDToken(userClaims, clientID, "", "")
tokenString, _, err := service.GenerateIDToken(userClaims, clientID, "", "")
require.NoError(t, err, "Failed to generate ID token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty")
@@ -699,7 +701,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
}
const clientID = "ecdsa-client-123"
tokenString, err := service.GenerateIDToken(userClaims, clientID, "", "")
tokenString, _, err := service.GenerateIDToken(userClaims, clientID, "", "")
require.NoError(t, err, "Failed to generate ID token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty")
@@ -737,7 +739,7 @@ func TestGenerateVerifyIdToken(t *testing.T) {
}
const clientID = "rsa-client-123"
tokenString, err := service.GenerateIDToken(userClaims, clientID, "", "")
tokenString, _, err := service.GenerateIDToken(userClaims, clientID, "", "")
require.NoError(t, err, "Failed to generate ID token with key")
assert.NotEmpty(t, tokenString, "Token should not be empty")

View File

@@ -354,12 +354,12 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
}
// Explicitly use the input clientID for the audience claim to ensure consistency
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, deviceAuth.Nonce, deviceAuth.AuthenticationMethod)
idToken, idTokenJti, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, deviceAuth.Nonce, deviceAuth.AuthenticationMethod)
if err != nil {
return CreatedTokens{}, err
}
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, *deviceAuth.UserID, deviceAuth.Scope, deviceAuth.AuthenticationMethod, tx)
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, *deviceAuth.UserID, deviceAuth.Scope, deviceAuth.AuthenticationMethod, idTokenJti, tx)
if err != nil {
return CreatedTokens{}, err
}
@@ -455,13 +455,13 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
authenticationMethod := authorizationCodeMetaData.AuthenticationMethod
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, authorizationCodeMetaData.Nonce, authenticationMethod)
idToken, idTokenJti, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, authorizationCodeMetaData.Nonce, authenticationMethod)
if err != nil {
return CreatedTokens{}, err
}
// Generate a refresh token
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, authenticationMethod, tx)
refreshToken, err := s.createRefreshToken(ctx, input.ClientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope, authenticationMethod, idTokenJti, tx)
if err != nil {
return CreatedTokens{}, err
}
@@ -595,13 +595,13 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
// Generate a new ID token
// There's no nonce here because we don't have one with the refresh token, but that's not required
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, "", authenticationMethods)
idToken, idTokenJti, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, "", authenticationMethods)
if err != nil {
return CreatedTokens{}, err
}
// Generate a new refresh token and invalidate the old one
newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, authenticationMethods, tx)
newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, authenticationMethods, idTokenJti, tx)
if err != nil {
return CreatedTokens{}, err
}
@@ -1197,7 +1197,7 @@ func (s *OidcService) UpdateAllowedUserGroups(ctx context.Context, id string, in
}
// ValidateEndSession returns the logout callback URL for the client if all the validations pass
func (s *OidcService) ValidateEndSession(ctx context.Context, input dto.OidcLogoutDto, userID string) (string, error) {
func (s *OidcService) ValidateEndSession(ctx context.Context, input dto.OidcLogoutDto, userID string) (callbackURL string, err error) {
// If no ID token hint is provided, return an error
if input.IdTokenHint == "" {
return "", &common.TokenInvalidError{}
@@ -1219,9 +1219,22 @@ func (s *OidcService) ValidateEndSession(ctx context.Context, input dto.OidcLogo
return "", &common.OidcClientIdNotMatchingError{}
}
subject, ok := token.Subject()
if !ok || subject != userID {
return "", &common.TokenInvalidError{}
}
idTokenJti, ok := token.JwtID()
if !ok {
return "", &common.TokenInvalidError{}
}
tx := s.db.Begin()
defer tx.Rollback()
// Check if the user has authorized the client before
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
err = s.db.
err = tx.
WithContext(ctx).
Preload("Client").
First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientID[0], userID).
@@ -1230,16 +1243,27 @@ func (s *OidcService) ValidateEndSession(ctx context.Context, input dto.OidcLogo
return "", &common.OidcMissingAuthorizationError{}
}
// If the client has no logout callback URLs, return an error
if len(userAuthorizedOIDCClient.Client.LogoutCallbackURLs) == 0 {
return "", &common.OidcNoCallbackURLError{}
// If the client has a callback URL, validate it
if len(userAuthorizedOIDCClient.Client.LogoutCallbackURLs) > 0 {
callbackURL, err = s.getLogoutCallbackURL(&userAuthorizedOIDCClient.Client, input.PostLogoutRedirectUri)
if err != nil {
return "", err
}
}
callbackURL, err := s.getLogoutCallbackURL(&userAuthorizedOIDCClient.Client, input.PostLogoutRedirectUri)
err = tx.
WithContext(ctx).
Where("user_id = ? AND client_id = ? AND id_token_jti = ?", userID, clientID[0], idTokenJti).
Delete(&model.OidcRefreshToken{}).
Error
if err != nil {
return "", err
}
if err := tx.Commit().Error; err != nil {
return "", fmt.Errorf("failed to commit transaction: %w", err)
}
return callbackURL, nil
}
@@ -1679,7 +1703,7 @@ func (s *OidcService) ListAccessibleOidcClients(ctx context.Context, userID stri
return dtos, response, err
}
func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, userID string, scope string, authenticationMethod string, tx *gorm.DB) (string, error) {
func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, userID string, scope string, authenticationMethod string, idTokenJti string, tx *gorm.DB) (string, error) {
refreshToken, err := utils.GenerateRandomAlphanumericString(40)
if err != nil {
return "", err
@@ -1692,6 +1716,7 @@ func (s *OidcService) createRefreshToken(ctx context.Context, clientID string, u
m := model.OidcRefreshToken{
ExpiresAt: datatype.DateTime(time.Now().Add(RefreshTokenDuration)),
Token: refreshTokenHash,
IdTokenJti: &idTokenJti,
ClientID: clientID,
UserID: userID,
Scope: scope,
@@ -1951,7 +1976,7 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
return nil, err
}
idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "", authenticationMethod)
idToken, _, err := s.jwtService.BuildIDToken(userClaims, clientID, "", authenticationMethod)
if err != nil {
return nil, err
}

View File

@@ -26,6 +26,7 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/storage"
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
)
@@ -623,7 +624,7 @@ func TestOidcServiceRefreshTokenAuthorizationState(t *testing.T) {
Scope: scope,
}).Error)
refreshToken, err := service.createRefreshToken(t.Context(), client.ID, user.ID, scope, AuthenticationMethodPhishingResistant, db)
refreshToken, err := service.createRefreshToken(t.Context(), client.ID, user.ID, scope, AuthenticationMethodPhishingResistant, "03f94e54-53c4-42f8-afe5-918ffd97a30e", db)
require.NoError(t, err)
return service, db, user, client, clientSecret, refreshToken, userGroup
@@ -728,7 +729,7 @@ func TestOidcServiceAuthenticationMethodsPersistence(t *testing.T) {
})
t.Run("stores authentication methods on refresh tokens", func(t *testing.T) {
_, err := service.createRefreshToken(t.Context(), "amr-client", "amr-user", "openid profile", authenticationMethod, db)
_, err := service.createRefreshToken(t.Context(), "amr-client", "amr-user", "openid profile", authenticationMethod, "03f94e54-53c4-42f8-afe5-918ffd97a30e", db)
require.NoError(t, err)
var refreshToken model.OidcRefreshToken
@@ -1251,6 +1252,115 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) {
})
}
func TestOidcService_ValidateEndSessionDeletesMatchingRefreshToken(t *testing.T) {
db := testutils.NewDatabaseForTest(t)
common.EnvConfig.EncryptionKey = []byte("0123456789abcdef0123456789abcdef")
mockConfig := NewTestAppConfigService(&model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"},
})
mockJwtService, err := NewJwtService(t.Context(), db, mockConfig)
require.NoError(t, err)
oidcService := &OidcService{
db: db,
jwtService: mockJwtService,
}
userID := "test-user-123"
clientID := "test-client-456"
otherClientID := "other-client-789"
otherIDTokenJti := "ac653f42-4781-49f2-bc7c-cc44503c3a1a" //nolint:gosec
userEmail := "test@example.com"
user := model.User{
Base: model.Base{ID: userID},
Email: &userEmail,
}
require.NoError(t, db.Create(&user).Error)
client := model.OidcClient{
Base: model.Base{ID: clientID},
Name: "Test Client",
LogoutCallbackURLs: []string{"https://example.com/logout"},
}
require.NoError(t, db.Create(&client).Error)
otherClient := model.OidcClient{
Base: model.Base{ID: otherClientID},
Name: "Other Client",
LogoutCallbackURLs: []string{"https://other.example.com/logout"},
}
require.NoError(t, db.Create(&otherClient).Error)
require.NoError(t, db.Create(&model.UserAuthorizedOidcClient{
UserID: userID,
ClientID: clientID,
}).Error)
userClaims := map[string]any{
"sub": userID,
"name": "Test User",
"email": userEmail,
}
idToken, idTokenJti, err := mockJwtService.GenerateIDToken(userClaims, clientID, "", "")
require.NoError(t, err)
refreshTokens := []model.OidcRefreshToken{
{
Token: "matching-refresh-token",
UserID: userID,
ClientID: clientID,
IdTokenJti: &idTokenJti,
ExpiresAt: datatype.DateTime(time.Now().Add(time.Hour)),
Scope: "openid profile",
},
{
Token: "same-client-different-session",
UserID: userID,
ClientID: clientID,
IdTokenJti: &otherIDTokenJti,
ExpiresAt: datatype.DateTime(time.Now().Add(time.Hour)),
Scope: "openid profile",
},
{
Token: "other-client-same-jti",
UserID: userID,
ClientID: otherClientID,
IdTokenJti: &idTokenJti,
ExpiresAt: datatype.DateTime(time.Now().Add(time.Hour)),
Scope: "openid profile",
},
{
Token: "legacy-refresh-token",
UserID: userID,
ClientID: clientID,
ExpiresAt: datatype.DateTime(time.Now().Add(time.Hour)),
Scope: "openid profile",
},
}
require.NoError(t, db.Create(&refreshTokens).Error)
callbackURL, err := oidcService.ValidateEndSession(t.Context(), dto.OidcLogoutDto{
IdTokenHint: idToken,
ClientId: clientID,
PostLogoutRedirectUri: "https://example.com/logout",
}, userID)
require.NoError(t, err)
assert.Equal(t, "https://example.com/logout", callbackURL)
var remainingTokens []model.OidcRefreshToken
require.NoError(t, db.Order("token").Find(&remainingTokens).Error)
remainingTokenValues := make([]string, len(remainingTokens))
for i, token := range remainingTokens {
remainingTokenValues[i] = token.Token
}
assert.ElementsMatch(t, []string{
"legacy-refresh-token",
"other-client-same-jti",
"same-client-different-session",
}, remainingTokenValues)
}
// Tests for prompt parameter parsing and handling
func TestParsePromptParameter(t *testing.T) {
t.Run("empty prompt returns empty slice", func(t *testing.T) {

View File

@@ -0,0 +1,4 @@
DROP INDEX IF EXISTS idx_oidc_refresh_tokens_id_token_jti;
ALTER TABLE oidc_refresh_tokens
DROP COLUMN id_token_jti;

View File

@@ -0,0 +1,5 @@
ALTER TABLE oidc_refresh_tokens
ADD COLUMN id_token_jti UUID;
CREATE INDEX idx_oidc_refresh_tokens_id_token_jti
ON oidc_refresh_tokens(user_id, client_id, id_token_jti);

View File

@@ -1,3 +1,10 @@
PRAGMA foreign_keys= OFF;
BEGIN;
ALTER TABLE oidc_authorization_codes DROP COLUMN authentication_method;
ALTER TABLE oidc_refresh_tokens DROP COLUMN authentication_method;
ALTER TABLE oidc_device_codes DROP COLUMN authentication_method;
COMMIT;
PRAGMA foreign_keys= ON;

View File

@@ -1,6 +1,12 @@
PRAGMA foreign_keys= OFF;
BEGIN;
ALTER TABLE oidc_authorization_codes
ADD COLUMN authentication_method TEXT NOT NULL DEFAULT '';
ALTER TABLE oidc_refresh_tokens
ADD COLUMN authentication_method TEXT NOT NULL DEFAULT '';
ALTER TABLE oidc_device_codes
ADD COLUMN authentication_method TEXT NOT NULL DEFAULT '';
COMMIT;
PRAGMA foreign_keys= ON;

View File

@@ -0,0 +1,10 @@
PRAGMA foreign_keys= OFF;
BEGIN;
DROP INDEX IF EXISTS idx_oidc_refresh_tokens_id_token_jti;
ALTER TABLE oidc_refresh_tokens
DROP COLUMN id_token_jti;
COMMIT;
PRAGMA foreign_keys= ON;

View File

@@ -0,0 +1,11 @@
PRAGMA foreign_keys= OFF;
BEGIN;
ALTER TABLE oidc_refresh_tokens
ADD COLUMN id_token_jti TEXT;
CREATE INDEX idx_oidc_refresh_tokens_id_token_jti
ON oidc_refresh_tokens(user_id, client_id, id_token_jti);
COMMIT;
PRAGMA foreign_keys= ON;

View File

@@ -1,6 +1,6 @@
{
"provider": "sqlite",
"version": 20260418120000,
"version": 20260518222000,
"tableOrder": ["users", "user_groups", "oidc_clients", "signup_tokens"],
"tables": {
"api_keys": [
@@ -171,6 +171,7 @@
{
"client_id": "3654a746-35d4-4321-ac61-0bdcff2b4055",
"authentication_method": "phr",
"id_token_jti": "dd75f6f6-ce0a-44b7-a645-7de390ccd2fa",
"created_at": "2025-11-25T12:39:02Z",
"expires_at": "2025-11-26T12:39:02Z",
"id": "4928604e-e689-410c-9b25-5b9b6db9e46e",

View File

@@ -130,7 +130,7 @@ test('End session without id token hint shows confirmation page', async ({ page
test('End session with id token hint redirects to callback URL', async ({ page }) => {
const client = oidcClients.nextcloud;
const idToken = await generateIdToken(users.tim, client.id);
const idToken = await generateIdToken("fe81c12a-7336-4aee-bebc-d901a873bf48", users.tim, client.id);
let redirectedCorrectly = false;
await page
.goto(

View File

@@ -13,11 +13,12 @@ type User = {
const privateKey = JSON.parse(PRIVATE_KEY_STRING);
const privateKeyImported = await jose.importJWK(privateKey, 'RS256');
export async function generateIdToken(user: User, clientId: string, expired = false) {
export async function generateIdToken(jti: string, user: User, clientId: string, expired = false) {
const now = Math.floor(Date.now() / 1000);
const expiration = expired ? now + 1 : now + 1000000000; // Either expired or valid for a long time
const payload = {
jti,
aud: clientId,
email: user.email,
email_verified: true,