mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-03-27 17:56:36 +00:00
feat: oidc client data preview (#624)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
@@ -48,6 +48,8 @@ func NewOidcController(group *gin.RouterGroup, authMiddleware *middleware.AuthMi
|
||||
group.DELETE("/oidc/clients/:id/logo", oc.deleteClientLogoHandler)
|
||||
group.POST("/oidc/clients/:id/logo", authMiddleware.Add(), fileSizeLimitMiddleware.Add(2<<20), oc.updateClientLogoHandler)
|
||||
|
||||
group.GET("/oidc/clients/:id/preview/:userId", authMiddleware.Add(), oc.getClientPreviewHandler)
|
||||
|
||||
group.POST("/oidc/device/authorize", oc.deviceAuthorizationHandler)
|
||||
group.POST("/oidc/device/verify", authMiddleware.WithAdminNotRequired().Add(), oc.verifyDeviceCodeHandler)
|
||||
group.GET("/oidc/device/info", authMiddleware.WithAdminNotRequired().Add(), oc.getDeviceCodeInfoHandler)
|
||||
@@ -721,3 +723,43 @@ func (oc *OidcController) getDeviceCodeInfoHandler(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, deviceCodeInfo)
|
||||
}
|
||||
|
||||
// getClientPreviewHandler godoc
|
||||
// @Summary Preview OIDC client data for user
|
||||
// @Description Get a preview of the OIDC data (ID token, access token, userinfo) that would be sent to the client for a specific user
|
||||
// @Tags OIDC
|
||||
// @Produce json
|
||||
// @Param id path string true "Client ID"
|
||||
// @Param userId path string true "User ID to preview data for"
|
||||
// @Param scopes query string false "Scopes to include in the preview (comma-separated)"
|
||||
// @Success 200 {object} dto.OidcClientPreviewDto "Preview data including ID token, access token, and userinfo payloads"
|
||||
// @Security BearerAuth
|
||||
// @Router /api/oidc/clients/{id}/preview/{userId} [get]
|
||||
func (oc *OidcController) getClientPreviewHandler(c *gin.Context) {
|
||||
clientID := c.Param("id")
|
||||
userID := c.Param("userId")
|
||||
scopes := c.Query("scopes")
|
||||
|
||||
if clientID == "" {
|
||||
_ = c.Error(&common.ValidationError{Message: "client ID is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if userID == "" {
|
||||
_ = c.Error(&common.ValidationError{Message: "user ID is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if scopes == "" {
|
||||
_ = c.Error(&common.ValidationError{Message: "scopes are required"})
|
||||
return
|
||||
}
|
||||
|
||||
preview, err := oc.oidcService.GetClientPreview(c.Request.Context(), clientID, userID, scopes)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, preview)
|
||||
}
|
||||
|
||||
@@ -145,3 +145,9 @@ type AuthorizedOidcClientDto struct {
|
||||
Scope string `json:"scope"`
|
||||
Client OidcClientMetaDataDto `json:"client"`
|
||||
}
|
||||
|
||||
type OidcClientPreviewDto struct {
|
||||
IdToken map[string]interface{} `json:"idToken"`
|
||||
AccessToken map[string]interface{} `json:"accessToken"`
|
||||
UserInfo map[string]interface{} `json:"userInfo"`
|
||||
}
|
||||
|
||||
@@ -234,7 +234,8 @@ func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string) (string, error) {
|
||||
// BuildIDToken creates an ID token with all claims
|
||||
func (s *JwtService) BuildIDToken(userClaims map[string]any, clientID string, nonce string) (jwt.Token, error) {
|
||||
now := time.Now()
|
||||
token, err := jwt.NewBuilder().
|
||||
Expiration(now.Add(1 * time.Hour)).
|
||||
@@ -242,33 +243,43 @@ func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string,
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Build()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build token: %w", err)
|
||||
return nil, fmt.Errorf("failed to build token: %w", err)
|
||||
}
|
||||
|
||||
err = SetAudienceString(token, clientID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||
return nil, fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||
}
|
||||
|
||||
err = SetTokenType(token, IDTokenJWTType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
|
||||
return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err)
|
||||
}
|
||||
|
||||
for k, v := range userClaims {
|
||||
err = token.Set(k, v)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set claim '%s': %w", k, err)
|
||||
return nil, fmt.Errorf("failed to set claim '%s': %w", k, err)
|
||||
}
|
||||
}
|
||||
|
||||
if nonce != "" {
|
||||
err = token.Set("nonce", nonce)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set claim 'nonce': %w", err)
|
||||
return nil, fmt.Errorf("failed to set claim 'nonce': %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GenerateIDToken creates and signs an ID token
|
||||
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string) (string, error) {
|
||||
token, err := s.BuildIDToken(userClaims, clientID, nonce)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
|
||||
if err != nil {
|
||||
@@ -311,7 +322,8 @@ func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
|
||||
// BuildOauthAccessToken creates an OAuth access token with all claims
|
||||
func (s *JwtService) BuildOauthAccessToken(user model.User, clientID string) (jwt.Token, error) {
|
||||
now := time.Now()
|
||||
token, err := jwt.NewBuilder().
|
||||
Subject(user.ID).
|
||||
@@ -320,17 +332,27 @@ func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string)
|
||||
Issuer(common.EnvConfig.AppURL).
|
||||
Build()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to build token: %w", err)
|
||||
return nil, fmt.Errorf("failed to build token: %w", err)
|
||||
}
|
||||
|
||||
err = SetAudienceString(token, clientID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||
return nil, fmt.Errorf("failed to set 'aud' claim in token: %w", err)
|
||||
}
|
||||
|
||||
err = SetTokenType(token, OAuthAccessTokenJWTType)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set 'type' claim in token: %w", err)
|
||||
return nil, fmt.Errorf("failed to set 'type' claim in token: %w", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GenerateOauthAccessToken creates and signs an OAuth access token
|
||||
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
|
||||
token, err := s.BuildOauthAccessToken(user, clientID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
alg, _ := s.privateKey.Algorithm()
|
||||
|
||||
@@ -841,97 +841,6 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) {
|
||||
var authorizedOidcClient model.UserAuthorizedOidcClient
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Preload("User.UserGroups").
|
||||
First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := authorizedOidcClient.User
|
||||
scopes := strings.Split(authorizedOidcClient.Scope, " ")
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": user.ID,
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "email") {
|
||||
claims["email"] = user.Email
|
||||
claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue()
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "groups") {
|
||||
userGroups := make([]string, len(user.UserGroups))
|
||||
for i, group := range user.UserGroups {
|
||||
userGroups[i] = group.Name
|
||||
}
|
||||
claims["groups"] = userGroups
|
||||
}
|
||||
|
||||
profileClaims := map[string]interface{}{
|
||||
"given_name": user.FirstName,
|
||||
"family_name": user.LastName,
|
||||
"name": user.FullName(),
|
||||
"preferred_username": user.Username,
|
||||
"picture": common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png",
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "profile") {
|
||||
// Add profile claims
|
||||
for k, v := range profileClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
|
||||
// Add custom claims
|
||||
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, userID, tx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, customClaim := range customClaims {
|
||||
// The value of the custom claim can be a JSON object or a string
|
||||
var jsonValue interface{}
|
||||
err := json.Unmarshal([]byte(customClaim.Value), &jsonValue)
|
||||
if err == nil {
|
||||
// It's JSON so we store it as an object
|
||||
claims[customClaim.Key] = jsonValue
|
||||
} else {
|
||||
// Marshalling failed, so we store it as a string
|
||||
claims[customClaim.Key] = customClaim.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "email") {
|
||||
claims["email"] = user.Email
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) UpdateAllowedUserGroups(ctx context.Context, id string, input dto.OidcUpdateAllowedUserGroupsDto) (client model.OidcClient, err error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
@@ -1519,3 +1428,168 @@ func (s *OidcService) verifyClientAssertionFromFederatedIdentities(ctx context.C
|
||||
// If we're here, the assertion is valid
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes string) (*dto.OidcClientPreviewDto, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
client, err := s.getClientInternal(ctx, clientID, tx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var user model.User
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Preload("UserGroups").
|
||||
First(&user, "id = ?", userID).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !s.IsUserGroupAllowedToAuthorize(user, client) {
|
||||
return nil, &common.OidcAccessDeniedError{}
|
||||
}
|
||||
|
||||
dummyAuthorizedClient := model.UserAuthorizedOidcClient{
|
||||
UserID: userID,
|
||||
ClientID: clientID,
|
||||
Scope: scopes,
|
||||
User: user,
|
||||
}
|
||||
|
||||
userClaims, err := s.getUserClaimsFromAuthorizedClient(ctx, &dummyAuthorizedClient, tx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idToken, err := s.jwtService.BuildIDToken(userClaims, clientID, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.BuildOauthAccessToken(user, clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idTokenPayload, err := utils.GetClaimsFromToken(idToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accessTokenPayload, err := utils.GetClaimsFromToken(accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &dto.OidcClientPreviewDto{
|
||||
IdToken: idTokenPayload,
|
||||
AccessToken: accessTokenPayload,
|
||||
UserInfo: userClaims,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetUserClaimsForClient(ctx context.Context, userID string, clientID string) (map[string]interface{}, error) {
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
|
||||
claims, err := s.getUserClaimsForClientInternal(ctx, userID, clientID, s.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID string, clientID string, tx *gorm.DB) (map[string]interface{}, error) {
|
||||
var authorizedOidcClient model.UserAuthorizedOidcClient
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Preload("User.UserGroups").
|
||||
First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.getUserClaimsFromAuthorizedClient(ctx, &authorizedOidcClient, tx)
|
||||
|
||||
}
|
||||
|
||||
func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]interface{}, error) {
|
||||
user := authorizedClient.User
|
||||
scopes := strings.Split(authorizedClient.Scope, " ")
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": user.ID,
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "email") {
|
||||
claims["email"] = user.Email
|
||||
claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue()
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "groups") {
|
||||
userGroups := make([]string, len(user.UserGroups))
|
||||
for i, group := range user.UserGroups {
|
||||
userGroups[i] = group.Name
|
||||
}
|
||||
claims["groups"] = userGroups
|
||||
}
|
||||
|
||||
profileClaims := map[string]interface{}{
|
||||
"given_name": user.FirstName,
|
||||
"family_name": user.LastName,
|
||||
"name": user.FullName(),
|
||||
"preferred_username": user.Username,
|
||||
"picture": common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png",
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "profile") {
|
||||
// Add profile claims
|
||||
for k, v := range profileClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
|
||||
// Add custom claims
|
||||
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, user.ID, tx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, customClaim := range customClaims {
|
||||
// The value of the custom claim can be a JSON object or a string
|
||||
var jsonValue interface{}
|
||||
err := json.Unmarshal([]byte(customClaim.Value), &jsonValue)
|
||||
if err == nil {
|
||||
// It's JSON, so we store it as an object
|
||||
claims[customClaim.Key] = jsonValue
|
||||
} else {
|
||||
// Marshaling failed, so we store it as a string
|
||||
claims[customClaim.Key] = customClaim.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, "email") {
|
||||
claims["email"] = user.Email
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
19
backend/internal/utils/jwt_util.go
Normal file
19
backend/internal/utils/jwt_util.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
)
|
||||
|
||||
func GetClaimsFromToken(token jwt.Token) (map[string]any, error) {
|
||||
claims := make(map[string]any)
|
||||
for _, key := range token.Keys() {
|
||||
var value any
|
||||
if err := token.Get(key, &value); err != nil {
|
||||
return nil, fmt.Errorf("failed to get claim %s: %w", key, err)
|
||||
}
|
||||
claims[key] = value
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
Reference in New Issue
Block a user