mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-05-18 10:59:53 +00:00
fix: access token renewal bypasses important checks
This commit is contained in:
@@ -534,6 +534,43 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
|||||||
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if storedRefreshToken.User.Disabled {
|
||||||
|
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var authorizedClient model.UserAuthorizedOidcClient
|
||||||
|
err = tx.
|
||||||
|
WithContext(ctx).
|
||||||
|
Where("user_id = ? AND client_id = ?", storedRefreshToken.UserID, input.ClientID).
|
||||||
|
First(&authorizedClient).
|
||||||
|
Error
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
err = tx.WithContext(ctx).Delete(&storedRefreshToken).Error
|
||||||
|
if err != nil {
|
||||||
|
return CreatedTokens{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Commit().Error
|
||||||
|
if err != nil {
|
||||||
|
return CreatedTokens{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||||
|
} else if err != nil {
|
||||||
|
return CreatedTokens{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if client.IsGroupRestricted {
|
||||||
|
err = tx.WithContext(ctx).Model(client).Association("AllowedUserGroups").Find(&client.AllowedUserGroups)
|
||||||
|
if err != nil {
|
||||||
|
return CreatedTokens{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !IsUserGroupAllowedToAuthorize(storedRefreshToken.User, *client) {
|
||||||
|
return CreatedTokens{}, &common.OidcAccessDeniedError{}
|
||||||
|
}
|
||||||
|
|
||||||
// Generate a new access token
|
// Generate a new access token
|
||||||
authenticationMethods := storedRefreshToken.AuthenticationMethod
|
authenticationMethods := storedRefreshToken.AuthenticationMethod
|
||||||
accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID, authenticationMethods)
|
accessToken, err := s.jwtService.GenerateOAuthAccessToken(storedRefreshToken.User, input.ClientID, authenticationMethods)
|
||||||
@@ -1500,6 +1537,15 @@ func (s *OidcService) RevokeAuthorizedClient(ctx context.Context, userID string,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = tx.
|
||||||
|
WithContext(ctx).
|
||||||
|
Where("user_id = ? AND client_id = ?", userID, clientID).
|
||||||
|
Delete(&model.OidcRefreshToken{}).
|
||||||
|
Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
err = tx.Commit().Error
|
err = tx.Commit().Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||||
@@ -562,6 +563,140 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOidcServiceRefreshTokenAuthorizationState(t *testing.T) {
|
||||||
|
newFixture := func(t *testing.T, isGroupRestricted bool) (*OidcService, *gorm.DB, model.User, model.OidcClient, string, string, *model.UserGroup) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
db := testutils.NewDatabaseForTest(t)
|
||||||
|
common.EnvConfig.EncryptionKey = []byte("0123456789abcdef0123456789abcdef")
|
||||||
|
|
||||||
|
mockConfig := NewTestAppConfigService(&model.AppConfig{
|
||||||
|
SessionDuration: model.AppConfigVariable{Value: "60"},
|
||||||
|
})
|
||||||
|
jwtService, err := NewJwtService(t.Context(), db, mockConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
service := &OidcService{
|
||||||
|
db: db,
|
||||||
|
jwtService: jwtService,
|
||||||
|
appConfigService: mockConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
email := "refresh-token-user@example.com"
|
||||||
|
user := model.User{
|
||||||
|
Username: "refresh-token-user",
|
||||||
|
Email: &email,
|
||||||
|
EmailVerified: true,
|
||||||
|
FirstName: "Refresh",
|
||||||
|
LastName: "User",
|
||||||
|
}
|
||||||
|
require.NoError(t, db.Create(&user).Error)
|
||||||
|
|
||||||
|
client, err := service.CreateClient(t.Context(), dto.OidcClientCreateDto{
|
||||||
|
OidcClientUpdateDto: dto.OidcClientUpdateDto{
|
||||||
|
Name: "Refresh Token Client",
|
||||||
|
CallbackURLs: []string{"https://example.com/callback"},
|
||||||
|
IsGroupRestricted: isGroupRestricted,
|
||||||
|
},
|
||||||
|
}, user.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
clientSecret, err := service.CreateClientSecret(t.Context(), client.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var userGroup *model.UserGroup
|
||||||
|
if isGroupRestricted {
|
||||||
|
group := model.UserGroup{
|
||||||
|
FriendlyName: "Allowed Group",
|
||||||
|
Name: "allowed-group",
|
||||||
|
}
|
||||||
|
require.NoError(t, db.Create(&group).Error)
|
||||||
|
require.NoError(t, db.Model(&user).Association("UserGroups").Append(&group))
|
||||||
|
require.NoError(t, db.Model(&client).Association("AllowedUserGroups").Append(&group))
|
||||||
|
userGroup = &group
|
||||||
|
}
|
||||||
|
|
||||||
|
scope := "openid profile email groups"
|
||||||
|
require.NoError(t, db.Create(&model.UserAuthorizedOidcClient{
|
||||||
|
UserID: user.ID,
|
||||||
|
ClientID: client.ID,
|
||||||
|
Scope: scope,
|
||||||
|
}).Error)
|
||||||
|
|
||||||
|
refreshToken, err := service.createRefreshToken(t.Context(), client.ID, user.ID, scope, AuthenticationMethodPhishingResistant, db)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return service, db, user, client, clientSecret, refreshToken, userGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshInput := func(client model.OidcClient, clientSecret string, refreshToken string) dto.OidcCreateTokensDto {
|
||||||
|
return dto.OidcCreateTokensDto{
|
||||||
|
GrantType: GrantTypeRefreshToken,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
ClientID: client.ID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("rejects refresh token after authorization revocation", func(t *testing.T) {
|
||||||
|
service, db, user, client, clientSecret, refreshToken, _ := newFixture(t, false)
|
||||||
|
|
||||||
|
err := service.RevokeAuthorizedClient(t.Context(), user.ID, client.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var refreshTokenCount int64
|
||||||
|
require.NoError(t, db.Model(&model.OidcRefreshToken{}).
|
||||||
|
Where("user_id = ? AND client_id = ?", user.ID, client.ID).
|
||||||
|
Count(&refreshTokenCount).Error)
|
||||||
|
assert.Zero(t, refreshTokenCount)
|
||||||
|
|
||||||
|
_, err = service.createTokenFromRefreshToken(t.Context(), refreshInput(client, clientSecret, refreshToken))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, &common.OidcInvalidRefreshTokenError{})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rejects and deletes stale refresh token without authorization record", func(t *testing.T) {
|
||||||
|
service, db, user, client, clientSecret, refreshToken, _ := newFixture(t, false)
|
||||||
|
|
||||||
|
require.NoError(t, db.
|
||||||
|
Where("user_id = ? AND client_id = ?", user.ID, client.ID).
|
||||||
|
Delete(&model.UserAuthorizedOidcClient{}).Error)
|
||||||
|
|
||||||
|
_, err := service.createTokenFromRefreshToken(t.Context(), refreshInput(client, clientSecret, refreshToken))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, &common.OidcInvalidRefreshTokenError{})
|
||||||
|
|
||||||
|
var refreshTokenCount int64
|
||||||
|
require.NoError(t, db.Model(&model.OidcRefreshToken{}).
|
||||||
|
Where("user_id = ? AND client_id = ?", user.ID, client.ID).
|
||||||
|
Count(&refreshTokenCount).Error)
|
||||||
|
assert.Zero(t, refreshTokenCount)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rejects refresh token for disabled user", func(t *testing.T) {
|
||||||
|
service, db, user, client, clientSecret, refreshToken, _ := newFixture(t, false)
|
||||||
|
|
||||||
|
require.NoError(t, db.Model(&model.User{}).
|
||||||
|
Where("id = ?", user.ID).
|
||||||
|
Update("disabled", true).Error)
|
||||||
|
|
||||||
|
_, err := service.createTokenFromRefreshToken(t.Context(), refreshInput(client, clientSecret, refreshToken))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, &common.OidcInvalidRefreshTokenError{})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rejects refresh token after user leaves allowed group", func(t *testing.T) {
|
||||||
|
service, db, user, client, clientSecret, refreshToken, userGroup := newFixture(t, true)
|
||||||
|
require.NotNil(t, userGroup)
|
||||||
|
|
||||||
|
require.NoError(t, db.Model(&user).Association("UserGroups").Delete(userGroup))
|
||||||
|
|
||||||
|
_, err := service.createTokenFromRefreshToken(t.Context(), refreshInput(client, clientSecret, refreshToken))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, &common.OidcAccessDeniedError{})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestOidcServiceAuthenticationMethodsPersistence(t *testing.T) {
|
func TestOidcServiceAuthenticationMethodsPersistence(t *testing.T) {
|
||||||
mockConfig := NewTestAppConfigService(&model.AppConfig{
|
mockConfig := NewTestAppConfigService(&model.AppConfig{
|
||||||
SessionDuration: model.AppConfigVariable{Value: "60"},
|
SessionDuration: model.AppConfigVariable{Value: "60"},
|
||||||
|
|||||||
Reference in New Issue
Block a user