mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-05-16 18:09:52 +00:00
fix(oidc): delete refresh tokens on end-session to prevent reuse after logout (#1458)
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com> Signed-off-by: wucm667 <stevenwucongmin@gmail.com>
This commit is contained in:
@@ -1218,9 +1218,12 @@ func (s *OidcService) ValidateEndSession(ctx context.Context, input dto.OidcLogo
|
||||
return "", &common.OidcClientIdNotMatchingError{}
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
defer tx.Rollback()
|
||||
|
||||
// Check if the user has authorized the client before
|
||||
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
|
||||
err = s.db.
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Preload("Client").
|
||||
First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientID[0], userID).
|
||||
@@ -1229,6 +1232,11 @@ func (s *OidcService) ValidateEndSession(ctx context.Context, input dto.OidcLogo
|
||||
return "", &common.OidcMissingAuthorizationError{}
|
||||
}
|
||||
|
||||
// Delete all refresh tokens for this user
|
||||
if err := tx.WithContext(ctx).Where("user_id = ?", userID).Delete(&model.OidcRefreshToken{}).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// If the client has no logout callback URLs, return an error
|
||||
if len(userAuthorizedOIDCClient.Client.LogoutCallbackURLs) == 0 {
|
||||
return "", &common.OidcNoCallbackURLError{}
|
||||
@@ -1239,6 +1247,10 @@ func (s *OidcService) ValidateEndSession(ctx context.Context, input dto.OidcLogo
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return "", fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
return callbackURL, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1249,3 +1249,80 @@ func TestPromptParameterConflicts(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOidcService_ValidateEndSession_DeletesRefreshTokens(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
common.EnvConfig.EncryptionKey = []byte("0123456789abcdef0123456789abcdef")
|
||||
|
||||
mockConfig := NewTestAppConfigService(&model.AppConfig{
|
||||
SessionDuration: model.AppConfigVariable{Value: "60"},
|
||||
})
|
||||
mockJwtService, err := NewJwtService(t.Context(), db, mockConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
oidcService := &OidcService{
|
||||
db: db,
|
||||
jwtService: mockJwtService,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
userID := "test-user-123"
|
||||
clientID := "test-client-456"
|
||||
|
||||
userEmail := "test@example.com"
|
||||
user := model.User{
|
||||
Base: model.Base{ID: userID},
|
||||
Email: &userEmail,
|
||||
}
|
||||
require.NoError(t, db.Create(&user).Error)
|
||||
|
||||
client := model.OidcClient{
|
||||
Base: model.Base{ID: clientID},
|
||||
Name: "Test Client",
|
||||
LogoutCallbackURLs: []string{"https://example.com/logout"},
|
||||
}
|
||||
require.NoError(t, db.Create(&client).Error)
|
||||
|
||||
authorization := model.UserAuthorizedOidcClient{
|
||||
UserID: userID,
|
||||
ClientID: clientID,
|
||||
}
|
||||
require.NoError(t, db.Create(&authorization).Error)
|
||||
|
||||
refreshToken1 := model.OidcRefreshToken{
|
||||
Token: "refresh-token-1",
|
||||
UserID: userID,
|
||||
ClientID: clientID,
|
||||
}
|
||||
require.NoError(t, db.Create(&refreshToken1).Error)
|
||||
|
||||
refreshToken2 := model.OidcRefreshToken{
|
||||
Token: "refresh-token-2",
|
||||
UserID: userID,
|
||||
ClientID: "other-client",
|
||||
}
|
||||
require.NoError(t, db.Create(&refreshToken2).Error)
|
||||
|
||||
userClaims := map[string]any{
|
||||
"sub": userID,
|
||||
"name": "Test User",
|
||||
"email": userEmail,
|
||||
}
|
||||
idToken, err := mockJwtService.GenerateIDToken(userClaims, clientID, "", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
input := dto.OidcLogoutDto{
|
||||
IdTokenHint: idToken,
|
||||
ClientId: clientID,
|
||||
PostLogoutRedirectUri: "https://example.com/logout",
|
||||
}
|
||||
|
||||
callbackURL, err := oidcService.ValidateEndSession(ctx, input, userID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://example.com/logout", callbackURL)
|
||||
|
||||
var remainingTokens []model.OidcRefreshToken
|
||||
err = db.Where("user_id = ?", userID).Find(&remainingTokens).Error
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, remainingTokens, "all refresh tokens for the user should be deleted")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user