feat: add database storage backend (#1091)

Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
Alessandro (Ale) Segala
2025-11-16 09:23:46 -08:00
committed by GitHub
parent 12125713a2
commit 29a1d3b778
28 changed files with 1491 additions and 258 deletions

View File

@@ -17,6 +17,7 @@ import (
"github.com/google/uuid"
"go.opentelemetry.io/otel/trace"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
@@ -101,9 +102,10 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
profilePicturePath := path.Join("profile-pictures", userID+".png")
// Try custom profile picture
if file, size, err := s.fileStorage.Open(ctx, profilePicturePath); err == nil {
file, size, err := s.fileStorage.Open(ctx, profilePicturePath)
if err == nil {
return file, size, nil
} else if err != nil && !errors.Is(err, fs.ErrNotExist) {
} else if !errors.Is(err, fs.ErrNotExist) {
return nil, 0, err
}
@@ -120,9 +122,10 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
// Try cached default for initials
defaultPicturePath := path.Join("profile-pictures", "defaults", user.Initials()+".png")
if file, size, err := s.fileStorage.Open(ctx, defaultPicturePath); err == nil {
file, size, err = s.fileStorage.Open(ctx, defaultPicturePath)
if err == nil {
return file, size, nil
} else if err != nil && !errors.Is(err, fs.ErrNotExist) {
} else if !errors.Is(err, fs.ErrNotExist) {
return nil, 0, err
}
@@ -133,12 +136,13 @@ func (s *UserService) GetProfilePicture(ctx context.Context, userID string) (io.
}
// Save the default picture for future use (in a goroutine to avoid blocking)
//nolint:contextcheck
defaultPictureBytes := defaultPicture.Bytes()
//nolint:contextcheck
go func() {
if err := s.fileStorage.Save(context.Background(), defaultPicturePath, bytes.NewReader(defaultPictureBytes)); err != nil {
slog.Error("Failed to cache default profile picture", slog.String("initials", user.Initials()), slog.Any("error", err))
// Use bytes.NewReader because we need an io.ReadSeeker
rErr := s.fileStorage.Save(context.Background(), defaultPicturePath, bytes.NewReader(defaultPictureBytes))
if rErr != nil {
slog.Error("Failed to cache default profile picture", slog.String("initials", user.Initials()), slog.Any("error", rErr))
}
}()
@@ -182,17 +186,30 @@ func (s *UserService) UpdateProfilePicture(ctx context.Context, userID string, f
}
func (s *UserService) DeleteUser(ctx context.Context, userID string, allowLdapDelete bool) error {
return s.db.Transaction(func(tx *gorm.DB) error {
return s.deleteUserInternal(ctx, userID, allowLdapDelete, tx)
err := s.db.Transaction(func(tx *gorm.DB) error {
return s.deleteUserInternal(ctx, tx, userID, allowLdapDelete)
})
if err != nil {
return fmt.Errorf("failed to delete user '%s': %w", userID, err)
}
// Storage operations must be executed outside of a transaction
profilePicturePath := path.Join("profile-pictures", userID+".png")
err = s.fileStorage.Delete(ctx, profilePicturePath)
if err != nil && !storage.IsNotExist(err) {
return fmt.Errorf("failed to delete profile picture for user '%s': %w", userID, err)
}
return nil
}
func (s *UserService) deleteUserInternal(ctx context.Context, userID string, allowLdapDelete bool, tx *gorm.DB) error {
func (s *UserService) deleteUserInternal(ctx context.Context, tx *gorm.DB, userID string, allowLdapDelete bool) error {
var user model.User
err := tx.
WithContext(ctx).
Where("id = ?", userID).
Clauses(clause.Locking{Strength: "UPDATE"}).
First(&user).
Error
if err != nil {
@@ -204,11 +221,6 @@ func (s *UserService) deleteUserInternal(ctx context.Context, userID string, all
return &common.LdapUserUpdateError{}
}
profilePicturePath := path.Join("profile-pictures", userID+".png")
if err := s.fileStorage.Delete(ctx, profilePicturePath); err != nil {
return err
}
err = tx.WithContext(ctx).Delete(&user).Error
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
@@ -286,16 +298,27 @@ func (s *UserService) applySignupDefaults(ctx context.Context, user *model.User,
// Apply default user groups
var groupIDs []string
if v := config.SignupDefaultUserGroupIDs.Value; v != "" && v != "[]" {
if err := json.Unmarshal([]byte(v), &groupIDs); err != nil {
v := config.SignupDefaultUserGroupIDs.Value
if v != "" && v != "[]" {
err := json.Unmarshal([]byte(v), &groupIDs)
if err != nil {
return fmt.Errorf("invalid SignupDefaultUserGroupIDs JSON: %w", err)
}
if len(groupIDs) > 0 {
var groups []model.UserGroup
if err := tx.WithContext(ctx).Where("id IN ?", groupIDs).Find(&groups).Error; err != nil {
err = tx.WithContext(ctx).
Where("id IN ?", groupIDs).
Find(&groups).
Error
if err != nil {
return fmt.Errorf("failed to find default user groups: %w", err)
}
if err := tx.WithContext(ctx).Model(user).Association("UserGroups").Replace(groups); err != nil {
err = tx.WithContext(ctx).
Model(user).
Association("UserGroups").
Replace(groups)
if err != nil {
return fmt.Errorf("failed to associate default user groups: %w", err)
}
}
@@ -303,12 +326,15 @@ func (s *UserService) applySignupDefaults(ctx context.Context, user *model.User,
// Apply default custom claims
var claims []dto.CustomClaimCreateDto
if v := config.SignupDefaultCustomClaims.Value; v != "" && v != "[]" {
if err := json.Unmarshal([]byte(v), &claims); err != nil {
v = config.SignupDefaultCustomClaims.Value
if v != "" && v != "[]" {
err := json.Unmarshal([]byte(v), &claims)
if err != nil {
return fmt.Errorf("invalid SignupDefaultCustomClaims JSON: %w", err)
}
if len(claims) > 0 {
if _, err := s.customClaimService.updateCustomClaimsInternal(ctx, UserID, user.ID, claims, tx); err != nil {
_, err = s.customClaimService.updateCustomClaimsInternal(ctx, UserID, user.ID, claims, tx)
if err != nil {
return fmt.Errorf("failed to apply default custom claims: %w", err)
}
}
@@ -345,6 +371,7 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
err := tx.
WithContext(ctx).
Where("id = ?", userID).
Clauses(clause.Locking{Strength: "UPDATE"}).
First(&user).
Error
if err != nil {
@@ -416,13 +443,11 @@ func (s *UserService) RequestOneTimeAccessEmailAsUnauthenticatedUser(ctx context
var userId string
err := s.db.Model(&model.User{}).Select("id").Where("email = ?", userID).First(&userId).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// Do not return error if user not found to prevent email enumeration
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil
} else {
return err
}
return nil
} else if err != nil {
return err
}
return s.requestOneTimeAccessEmailInternal(ctx, userId, redirectPath, 15*time.Minute)
@@ -513,7 +538,9 @@ func (s *UserService) ExchangeOneTimeAccessToken(ctx context.Context, token stri
var oneTimeAccessToken model.OneTimeAccessToken
err := tx.
WithContext(ctx).
Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).Preload("User").
Where("token = ? AND expires_at > ?", token, datatype.DateTime(time.Now())).
Preload("User").
Clauses(clause.Locking{Strength: "UPDATE"}).
First(&oneTimeAccessToken).
Error
if err != nil {
@@ -679,7 +706,7 @@ func (s *UserService) ResetProfilePicture(ctx context.Context, userID string) er
return nil
}
func (s *UserService) disableUserInternal(ctx context.Context, userID string, tx *gorm.DB) error {
func (s *UserService) disableUserInternal(ctx context.Context, tx *gorm.DB, userID string) error {
return tx.
WithContext(ctx).
Model(&model.User{}).
@@ -720,6 +747,7 @@ func (s *UserService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAd
err := tx.
WithContext(ctx).
Where("token = ?", signupData.Token).
Clauses(clause.Locking{Strength: "UPDATE"}).
First(&signupToken).
Error
if err != nil {