mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-05-20 20:09:53 +00:00
Merge branch 'main' into feat/add_form_post
This commit is contained in:
@@ -2,6 +2,7 @@ package bootstrap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/job"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
)
|
||||
@@ -60,7 +62,9 @@ func Bootstrap(ctx context.Context) error {
|
||||
}
|
||||
|
||||
waitUntil, err := svc.appLockService.Acquire(ctx, false)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrLockUnavailable) {
|
||||
return errors.New("it appears that there's already one instance of Pocket ID running; running multiple replicas of Pocket ID is currently not supported")
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to acquire application lock: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -119,11 +119,10 @@ func acquireImportLock(ctx context.Context, db *gorm.DB, force bool) error {
|
||||
defer cancel()
|
||||
|
||||
waitUntil, err := appLockService.Acquire(opCtx, force)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrLockUnavailable) {
|
||||
//nolint:staticcheck
|
||||
return errors.New("Pocket ID must be stopped before importing data; please stop the running instance or run with --forcefully-acquire-lock to terminate the other instance")
|
||||
}
|
||||
if errors.Is(err, service.ErrLockUnavailable) {
|
||||
//nolint:staticcheck
|
||||
return errors.New("Pocket ID must be stopped before importing data; please stop the running instance or run with --forcefully-acquire-lock to terminate the other instance")
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to acquire application lock: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,9 @@ package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -34,6 +36,7 @@ func NewAppImagesController(
|
||||
group.PUT("/application-images/favicon", authMiddleware.Add(), controller.updateFaviconHandler)
|
||||
group.PUT("/application-images/default-profile-picture", authMiddleware.Add(), controller.updateDefaultProfilePicture)
|
||||
|
||||
group.DELETE("/application-images/background", authMiddleware.Add(), controller.deleteBackgroundImageHandler)
|
||||
group.DELETE("/application-images/default-profile-picture", authMiddleware.Add(), controller.deleteDefaultProfilePicture)
|
||||
}
|
||||
|
||||
@@ -192,12 +195,27 @@ func (c *AppImagesController) updateBackgroundImageHandler(ctx *gin.Context) {
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// deleteBackgroundImageHandler godoc
|
||||
// @Summary Delete background image
|
||||
// @Description Delete the application background image
|
||||
// @Tags Application Images
|
||||
// @Success 204 "No Content"
|
||||
// @Router /api/application-images/background [delete]
|
||||
func (c *AppImagesController) deleteBackgroundImageHandler(ctx *gin.Context) {
|
||||
if err := c.appImagesService.DeleteImage(ctx.Request.Context(), "background"); err != nil {
|
||||
_ = ctx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// updateFaviconHandler godoc
|
||||
// @Summary Update favicon
|
||||
// @Description Update the application favicon
|
||||
// @Tags Application Images
|
||||
// @Accept multipart/form-data
|
||||
// @Param file formData file true "Favicon file (.ico)"
|
||||
// @Param file formData file true "Favicon file (.svg/.png/.ico)"
|
||||
// @Success 204 "No Content"
|
||||
// @Router /api/application-images/favicon [put]
|
||||
func (c *AppImagesController) updateFaviconHandler(ctx *gin.Context) {
|
||||
@@ -208,8 +226,9 @@ func (c *AppImagesController) updateFaviconHandler(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
fileType := utils.GetFileExtension(file.Filename)
|
||||
if fileType != "ico" {
|
||||
_ = ctx.Error(&common.WrongFileTypeError{ExpectedFileType: ".ico"})
|
||||
mimeType := utils.GetImageMimeType(strings.ToLower(fileType))
|
||||
if !slices.Contains([]string{"image/svg+xml", "image/png", "image/x-icon"}, mimeType) {
|
||||
_ = ctx.Error(&common.WrongFileTypeError{ExpectedFileType: ".svg or .png or .ico"})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ func (s *Scheduler) RegisterAnalyticsJob(ctx context.Context, appConfig *service
|
||||
appConfig: appConfig,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
return s.RegisterJob(ctx, "SendHeartbeat", gocron.DurationJob(24*time.Hour), jobs.sendHeartbeat, true)
|
||||
return s.RegisterJob(ctx, "SendHeartbeat", gocron.DurationJob(24*time.Hour), jobs.sendHeartbeat, service.RegisterJobOpts{RunImmediately: true})
|
||||
}
|
||||
|
||||
type AnalyticsJob struct {
|
||||
|
||||
@@ -22,7 +22,7 @@ func (s *Scheduler) RegisterApiKeyExpiryJob(ctx context.Context, apiKeyService *
|
||||
}
|
||||
|
||||
// Send every day at midnight
|
||||
return s.RegisterJob(ctx, "ExpiredApiKeyEmailJob", gocron.CronJob("0 0 * * *", false), jobs.checkAndNotifyExpiringApiKeys, false)
|
||||
return s.RegisterJob(ctx, "ExpiredApiKeyEmailJob", gocron.CronJob("0 0 * * *", false), jobs.checkAndNotifyExpiringApiKeys, service.RegisterJobOpts{})
|
||||
}
|
||||
|
||||
func (j *ApiKeyEmailJobs) checkAndNotifyExpiringApiKeys(ctx context.Context) error {
|
||||
@@ -42,7 +42,11 @@ func (j *ApiKeyEmailJobs) checkAndNotifyExpiringApiKeys(ctx context.Context) err
|
||||
}
|
||||
err = j.apiKeyService.SendApiKeyExpiringSoonEmail(ctx, key)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "Failed to send expiring API key notification email", slog.String("key", key.ID), slog.Any("error", err))
|
||||
slog.ErrorContext(ctx, "Failed to send expiring API key notification email",
|
||||
slog.String("key", key.ID),
|
||||
slog.String("user", key.User.ID),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -7,28 +7,37 @@ import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/go-co-op/gocron/v2"
|
||||
backoff "github.com/cenkalti/backoff/v5"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
)
|
||||
|
||||
func (s *Scheduler) RegisterDbCleanupJobs(ctx context.Context, db *gorm.DB) error {
|
||||
jobs := &DbCleanupJobs{db: db}
|
||||
|
||||
// Run every 24 hours (but with some jitter so they don't run at the exact same time), and now
|
||||
def := gocron.DurationRandomJob(24*time.Hour-2*time.Minute, 24*time.Hour+2*time.Minute)
|
||||
newBackOff := func() *backoff.ExponentialBackOff {
|
||||
bo := backoff.NewExponentialBackOff()
|
||||
bo.Multiplier = 4
|
||||
bo.RandomizationFactor = 0.1
|
||||
bo.InitialInterval = time.Second
|
||||
bo.MaxInterval = 45 * time.Second
|
||||
return bo
|
||||
}
|
||||
|
||||
// Use exponential backoff for each DB cleanup job so transient query failures are retried automatically rather than causing an immediate job failure
|
||||
return errors.Join(
|
||||
s.RegisterJob(ctx, "ClearWebauthnSessions", def, jobs.clearWebauthnSessions, true),
|
||||
s.RegisterJob(ctx, "ClearOneTimeAccessTokens", def, jobs.clearOneTimeAccessTokens, true),
|
||||
s.RegisterJob(ctx, "ClearSignupTokens", def, jobs.clearSignupTokens, true),
|
||||
s.RegisterJob(ctx, "ClearEmailVerificationTokens", def, jobs.clearEmailVerificationTokens, true),
|
||||
s.RegisterJob(ctx, "ClearOidcAuthorizationCodes", def, jobs.clearOidcAuthorizationCodes, true),
|
||||
s.RegisterJob(ctx, "ClearOidcRefreshTokens", def, jobs.clearOidcRefreshTokens, true),
|
||||
s.RegisterJob(ctx, "ClearReauthenticationTokens", def, jobs.clearReauthenticationTokens, true),
|
||||
s.RegisterJob(ctx, "ClearAuditLogs", def, jobs.clearAuditLogs, true),
|
||||
s.RegisterJob(ctx, "ClearWebauthnSessions", jobDefWithJitter(24*time.Hour), jobs.clearWebauthnSessions, service.RegisterJobOpts{RunImmediately: true, BackOff: newBackOff()}),
|
||||
s.RegisterJob(ctx, "ClearOneTimeAccessTokens", jobDefWithJitter(24*time.Hour), jobs.clearOneTimeAccessTokens, service.RegisterJobOpts{RunImmediately: true, BackOff: newBackOff()}),
|
||||
s.RegisterJob(ctx, "ClearSignupTokens", jobDefWithJitter(24*time.Hour), jobs.clearSignupTokens, service.RegisterJobOpts{RunImmediately: true, BackOff: newBackOff()}),
|
||||
s.RegisterJob(ctx, "ClearEmailVerificationTokens", jobDefWithJitter(24*time.Hour), jobs.clearEmailVerificationTokens, service.RegisterJobOpts{RunImmediately: true, BackOff: newBackOff()}),
|
||||
s.RegisterJob(ctx, "ClearOidcAuthorizationCodes", jobDefWithJitter(24*time.Hour), jobs.clearOidcAuthorizationCodes, service.RegisterJobOpts{RunImmediately: true, BackOff: newBackOff()}),
|
||||
s.RegisterJob(ctx, "ClearOidcRefreshTokens", jobDefWithJitter(24*time.Hour), jobs.clearOidcRefreshTokens, service.RegisterJobOpts{RunImmediately: true, BackOff: newBackOff()}),
|
||||
s.RegisterJob(ctx, "ClearReauthenticationTokens", jobDefWithJitter(24*time.Hour), jobs.clearReauthenticationTokens, service.RegisterJobOpts{RunImmediately: true, BackOff: newBackOff()}),
|
||||
s.RegisterJob(ctx, "ClearAuditLogs", jobDefWithJitter(24*time.Hour), jobs.clearAuditLogs, service.RegisterJobOpts{RunImmediately: true, BackOff: newBackOff()}),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -13,20 +13,26 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||
)
|
||||
|
||||
func (s *Scheduler) RegisterFileCleanupJobs(ctx context.Context, db *gorm.DB, fileStorage storage.FileStorage) error {
|
||||
jobs := &FileCleanupJobs{db: db, fileStorage: fileStorage}
|
||||
|
||||
err := s.RegisterJob(ctx, "ClearUnusedDefaultProfilePictures", gocron.DurationJob(24*time.Hour), jobs.clearUnusedDefaultProfilePictures, false)
|
||||
var errs []error
|
||||
errs = append(errs,
|
||||
s.RegisterJob(ctx, "ClearUnusedDefaultProfilePictures", gocron.DurationJob(24*time.Hour), jobs.clearUnusedDefaultProfilePictures, service.RegisterJobOpts{}),
|
||||
)
|
||||
|
||||
// Only necessary for file system storage
|
||||
if fileStorage.Type() == storage.TypeFileSystem {
|
||||
err = errors.Join(err, s.RegisterJob(ctx, "ClearOrphanedTempFiles", gocron.DurationJob(12*time.Hour), jobs.clearOrphanedTempFiles, true))
|
||||
errs = append(errs,
|
||||
s.RegisterJob(ctx, "ClearOrphanedTempFiles", gocron.DurationJob(12*time.Hour), jobs.clearOrphanedTempFiles, service.RegisterJobOpts{RunImmediately: true}),
|
||||
)
|
||||
}
|
||||
|
||||
return err
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
type FileCleanupJobs struct {
|
||||
@@ -68,7 +74,8 @@ func (j *FileCleanupJobs) clearUnusedDefaultProfilePictures(ctx context.Context)
|
||||
// If these initials aren't used by any user, delete the file
|
||||
if _, ok := initialsInUse[initials]; !ok {
|
||||
filePath := path.Join(defaultPicturesDir, filename)
|
||||
if err := j.fileStorage.Delete(ctx, filePath); err != nil {
|
||||
err = j.fileStorage.Delete(ctx, filePath)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "Failed to delete unused default profile picture", slog.String("path", filePath), slog.Any("error", err))
|
||||
} else {
|
||||
filesDeleted++
|
||||
@@ -95,8 +102,9 @@ func (j *FileCleanupJobs) clearOrphanedTempFiles(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := j.fileStorage.Delete(ctx, p.Path); err != nil {
|
||||
slog.ErrorContext(ctx, "Failed to delete temp file", slog.String("path", p.Path), slog.Any("error", err))
|
||||
rErr := j.fileStorage.Delete(ctx, p.Path)
|
||||
if rErr != nil {
|
||||
slog.ErrorContext(ctx, "Failed to delete temp file", slog.String("path", p.Path), slog.Any("error", rErr))
|
||||
return nil
|
||||
}
|
||||
deleted++
|
||||
|
||||
@@ -23,7 +23,7 @@ func (s *Scheduler) RegisterGeoLiteUpdateJobs(ctx context.Context, geoLiteServic
|
||||
jobs := &GeoLiteUpdateJobs{geoLiteService: geoLiteService}
|
||||
|
||||
// Run every 24 hours (and right away)
|
||||
return s.RegisterJob(ctx, "UpdateGeoLiteDB", gocron.DurationJob(24*time.Hour), jobs.updateGoeLiteDB, true)
|
||||
return s.RegisterJob(ctx, "UpdateGeoLiteDB", gocron.DurationJob(24*time.Hour), jobs.updateGoeLiteDB, service.RegisterJobOpts{RunImmediately: true})
|
||||
}
|
||||
|
||||
func (j *GeoLiteUpdateJobs) updateGoeLiteDB(ctx context.Context) error {
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/go-co-op/gocron/v2"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
)
|
||||
|
||||
@@ -17,8 +15,8 @@ type LdapJobs struct {
|
||||
func (s *Scheduler) RegisterLdapJobs(ctx context.Context, ldapService *service.LdapService, appConfigService *service.AppConfigService) error {
|
||||
jobs := &LdapJobs{ldapService: ldapService, appConfigService: appConfigService}
|
||||
|
||||
// Register the job to run every hour
|
||||
return s.RegisterJob(ctx, "SyncLdap", gocron.DurationJob(time.Hour), jobs.syncLdap, true)
|
||||
// Register the job to run every hour (with some jitter)
|
||||
return s.RegisterJob(ctx, "SyncLdap", jobDefWithJitter(time.Hour), jobs.syncLdap, service.RegisterJobOpts{RunImmediately: true})
|
||||
}
|
||||
|
||||
func (j *LdapJobs) syncLdap(ctx context.Context) error {
|
||||
|
||||
@@ -5,9 +5,13 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
backoff "github.com/cenkalti/backoff/v5"
|
||||
"github.com/go-co-op/gocron/v2"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||
)
|
||||
|
||||
type Scheduler struct {
|
||||
@@ -33,16 +37,12 @@ func (s *Scheduler) RemoveJob(name string) error {
|
||||
if job.Name() == name {
|
||||
err := s.scheduler.RemoveJob(job.ID())
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to unqueue job %q with ID %q: %w", name, job.ID().String(), err))
|
||||
errs = append(errs, fmt.Errorf("failed to dequeue job %q with ID %q: %w", name, job.ID().String(), err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
return nil
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
// Run the scheduler.
|
||||
@@ -64,7 +64,29 @@ func (s *Scheduler) Run(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Scheduler) RegisterJob(ctx context.Context, name string, def gocron.JobDefinition, job func(ctx context.Context) error, runImmediately bool, extraOptions ...gocron.JobOption) error {
|
||||
func (s *Scheduler) RegisterJob(ctx context.Context, name string, def gocron.JobDefinition, jobFn func(ctx context.Context) error, opts service.RegisterJobOpts) error {
|
||||
// If a BackOff strategy is provided, wrap the job with retry logic
|
||||
if opts.BackOff != nil {
|
||||
origJob := jobFn
|
||||
jobFn = func(ctx context.Context) error {
|
||||
_, err := backoff.Retry(
|
||||
ctx,
|
||||
func() (struct{}, error) {
|
||||
return struct{}{}, origJob(ctx)
|
||||
},
|
||||
backoff.WithBackOff(opts.BackOff),
|
||||
backoff.WithNotify(func(err error, d time.Duration) {
|
||||
slog.WarnContext(ctx, "Job failed, retrying",
|
||||
slog.String("name", name),
|
||||
slog.Any("error", err),
|
||||
slog.Duration("retryIn", d),
|
||||
)
|
||||
}),
|
||||
)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
jobOptions := []gocron.JobOption{
|
||||
gocron.WithContext(ctx),
|
||||
gocron.WithName(name),
|
||||
@@ -91,13 +113,13 @@ func (s *Scheduler) RegisterJob(ctx context.Context, name string, def gocron.Job
|
||||
),
|
||||
}
|
||||
|
||||
if runImmediately {
|
||||
if opts.RunImmediately {
|
||||
jobOptions = append(jobOptions, gocron.JobOption(gocron.WithStartImmediately()))
|
||||
}
|
||||
|
||||
jobOptions = append(jobOptions, extraOptions...)
|
||||
jobOptions = append(jobOptions, opts.ExtraOptions...)
|
||||
|
||||
_, err := s.scheduler.NewJob(def, gocron.NewTask(job), jobOptions...)
|
||||
_, err := s.scheduler.NewJob(def, gocron.NewTask(jobFn), jobOptions...)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register job %q: %w", name, err)
|
||||
@@ -105,3 +127,9 @@ func (s *Scheduler) RegisterJob(ctx context.Context, name string, def gocron.Job
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func jobDefWithJitter(interval time.Duration) gocron.JobDefinition {
|
||||
const jitter = 5 * time.Minute
|
||||
|
||||
return gocron.DurationRandomJob(interval-jitter, interval+jitter)
|
||||
}
|
||||
|
||||
@@ -16,8 +16,8 @@ type ScimJobs struct {
|
||||
func (s *Scheduler) RegisterScimJobs(ctx context.Context, scimService *service.ScimService) error {
|
||||
jobs := &ScimJobs{scimService: scimService}
|
||||
|
||||
// Register the job to run every hour
|
||||
return s.RegisterJob(ctx, "SyncScim", gocron.DurationJob(time.Hour), jobs.SyncScim, true)
|
||||
// Register the job to run every hour (with some jitter)
|
||||
return s.RegisterJob(ctx, "SyncScim", gocron.DurationJob(time.Hour), jobs.SyncScim, service.RegisterJobOpts{RunImmediately: true})
|
||||
}
|
||||
|
||||
func (j *ScimJobs) SyncScim(ctx context.Context) error {
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||
@@ -205,36 +206,33 @@ func (s *ApiKeyService) ListExpiringApiKeys(ctx context.Context, daysAhead int)
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) SendApiKeyExpiringSoonEmail(ctx context.Context, apiKey model.ApiKey) error {
|
||||
user := apiKey.User
|
||||
|
||||
if user.ID == "" {
|
||||
if err := s.db.WithContext(ctx).First(&user, "id = ?", apiKey.UserID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if user.Email == nil {
|
||||
if apiKey.User.Email == nil {
|
||||
return &common.UserEmailNotSetError{}
|
||||
}
|
||||
|
||||
err := SendEmail(ctx, s.emailService, email.Address{
|
||||
Name: user.FullName(),
|
||||
Email: *user.Email,
|
||||
Name: apiKey.User.FullName(),
|
||||
Email: *apiKey.User.Email,
|
||||
}, ApiKeyExpiringSoonTemplate, &ApiKeyExpiringSoonTemplateData{
|
||||
ApiKeyName: apiKey.Name,
|
||||
ExpiresAt: apiKey.ExpiresAt.ToTime(),
|
||||
Name: user.FirstName,
|
||||
Name: apiKey.User.FirstName,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("error sending notification email: %w", err)
|
||||
}
|
||||
|
||||
// Mark the API key as having had an expiration email sent
|
||||
return s.db.WithContext(ctx).
|
||||
err = s.db.WithContext(ctx).
|
||||
Model(&model.ApiKey{}).
|
||||
Where("id = ?", apiKey.ID).
|
||||
Update("expiration_email_sent", true).
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("error recording expiration sent email in database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) initStaticApiKeyUser(ctx context.Context) (user model.User, err error) {
|
||||
|
||||
@@ -73,7 +73,10 @@ func (lv *lockValue) Unmarshal(raw string) error {
|
||||
// Acquire obtains the lock. When force is true, the lock is stolen from any existing owner.
|
||||
// If the lock is forcefully acquired, it blocks until the previous lock has expired.
|
||||
func (s *AppLockService) Acquire(ctx context.Context, force bool) (waitUntil time.Time, err error) {
|
||||
tx := s.db.Begin()
|
||||
tx := s.db.WithContext(ctx).Begin()
|
||||
if tx.Error != nil {
|
||||
return time.Time{}, fmt.Errorf("begin lock transaction: %w", tx.Error)
|
||||
}
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
@@ -93,7 +96,8 @@ func (s *AppLockService) Acquire(ctx context.Context, force bool) (waitUntil tim
|
||||
|
||||
var prevLock lockValue
|
||||
if prevLockRaw != "" {
|
||||
if err := prevLock.Unmarshal(prevLockRaw); err != nil {
|
||||
err = prevLock.Unmarshal(prevLockRaw)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("decode existing lock value: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -139,7 +143,8 @@ func (s *AppLockService) Acquire(ctx context.Context, force bool) (waitUntil tim
|
||||
return time.Time{}, fmt.Errorf("lock acquisition failed: %w", res.Error)
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
err = tx.Commit().Error
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("commit lock acquisition: %w", err)
|
||||
}
|
||||
|
||||
@@ -174,7 +179,8 @@ func (s *AppLockService) RunRenewal(ctx context.Context) error {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
if err := s.renew(ctx); err != nil {
|
||||
err := s.renew(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("renew lock: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -183,33 +189,43 @@ func (s *AppLockService) RunRenewal(ctx context.Context) error {
|
||||
|
||||
// Release releases the lock if it is held by this process.
|
||||
func (s *AppLockService) Release(ctx context.Context) error {
|
||||
opCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
db, err := s.db.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get DB connection: %w", err)
|
||||
}
|
||||
|
||||
var query string
|
||||
switch s.db.Name() {
|
||||
case "sqlite":
|
||||
query = `
|
||||
DELETE FROM kv
|
||||
WHERE key = ?
|
||||
AND json_extract(value, '$.lock_id') = ?
|
||||
`
|
||||
DELETE FROM kv
|
||||
WHERE key = ?
|
||||
AND json_extract(value, '$.lock_id') = ?
|
||||
`
|
||||
case "postgres":
|
||||
query = `
|
||||
DELETE FROM kv
|
||||
WHERE key = $1
|
||||
AND value::json->>'lock_id' = $2
|
||||
`
|
||||
DELETE FROM kv
|
||||
WHERE key = $1
|
||||
AND value::json->>'lock_id' = $2
|
||||
`
|
||||
default:
|
||||
return fmt.Errorf("unsupported database dialect: %s", s.db.Name())
|
||||
}
|
||||
|
||||
res := s.db.WithContext(opCtx).Exec(query, lockKey, s.lockID)
|
||||
if res.Error != nil {
|
||||
return fmt.Errorf("release lock failed: %w", res.Error)
|
||||
opCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
res, err := db.ExecContext(opCtx, query, lockKey, s.lockID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("release lock failed: %w", err)
|
||||
}
|
||||
|
||||
if res.RowsAffected == 0 {
|
||||
count, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to count affected rows: %w", err)
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
slog.Warn("Application lock not held by this process, cannot release",
|
||||
slog.Int64("process_id", s.processID),
|
||||
slog.String("host_id", s.hostID),
|
||||
@@ -225,6 +241,11 @@ func (s *AppLockService) Release(ctx context.Context) error {
|
||||
|
||||
// renew tries to renew the lock, retrying up to renewRetries times (sleeping 1s between attempts).
|
||||
func (s *AppLockService) renew(ctx context.Context) error {
|
||||
db, err := s.db.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get DB connection: %w", err)
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 1; attempt <= renewRetries; attempt++ {
|
||||
now := time.Now()
|
||||
@@ -246,42 +267,56 @@ func (s *AppLockService) renew(ctx context.Context) error {
|
||||
switch s.db.Name() {
|
||||
case "sqlite":
|
||||
query = `
|
||||
UPDATE kv
|
||||
SET value = ?
|
||||
WHERE key = ?
|
||||
AND json_extract(value, '$.lock_id') = ?
|
||||
AND json_extract(value, '$.expires_at') > ?
|
||||
`
|
||||
UPDATE kv
|
||||
SET value = ?
|
||||
WHERE key = ?
|
||||
AND json_extract(value, '$.lock_id') = ?
|
||||
AND json_extract(value, '$.expires_at') > ?
|
||||
`
|
||||
case "postgres":
|
||||
query = `
|
||||
UPDATE kv
|
||||
SET value = $1
|
||||
WHERE key = $2
|
||||
AND value::json->>'lock_id' = $3
|
||||
AND ((value::json->>'expires_at')::bigint > $4)
|
||||
`
|
||||
UPDATE kv
|
||||
SET value = $1
|
||||
WHERE key = $2
|
||||
AND value::json->>'lock_id' = $3
|
||||
AND ((value::json->>'expires_at')::bigint > $4)
|
||||
`
|
||||
default:
|
||||
return fmt.Errorf("unsupported database dialect: %s", s.db.Name())
|
||||
}
|
||||
|
||||
opCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
res := s.db.WithContext(opCtx).Exec(query, raw, lockKey, s.lockID, nowUnix)
|
||||
res, err := db.ExecContext(opCtx, query, raw, lockKey, s.lockID, nowUnix)
|
||||
cancel()
|
||||
|
||||
switch {
|
||||
case res.Error != nil:
|
||||
lastErr = fmt.Errorf("lock renewal failed: %w", res.Error)
|
||||
case res.RowsAffected == 0:
|
||||
// Must be after checking res.Error
|
||||
return ErrLockLost
|
||||
default:
|
||||
// Query succeeded, but may have updated 0 rows
|
||||
if err == nil {
|
||||
count, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to count affected rows: %w", err)
|
||||
}
|
||||
|
||||
// If no rows were updated, we lost the lock
|
||||
if count == 0 {
|
||||
return ErrLockLost
|
||||
}
|
||||
|
||||
// All good
|
||||
slog.Debug("Renewed application lock",
|
||||
slog.Int64("process_id", s.processID),
|
||||
slog.String("host_id", s.hostID),
|
||||
slog.Duration("duration", time.Since(now)),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we're here, we have an error that can be retried
|
||||
slog.Debug("Application lock renewal attempt failed",
|
||||
slog.Any("error", err),
|
||||
slog.Duration("duration", time.Since(now)),
|
||||
)
|
||||
lastErr = fmt.Errorf("lock renewal failed: %w", err)
|
||||
|
||||
// Wait before next attempt or cancel if context is done
|
||||
if attempt < renewRetries {
|
||||
select {
|
||||
|
||||
@@ -49,6 +49,23 @@ func readLockValue(t *testing.T, db *gorm.DB) lockValue {
|
||||
return value
|
||||
}
|
||||
|
||||
func lockDatabaseForWrite(t *testing.T, db *gorm.DB) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
tx := db.Begin()
|
||||
require.NoError(t, tx.Error)
|
||||
|
||||
// Keep a write transaction open to block other queries.
|
||||
err := tx.Exec(
|
||||
`INSERT INTO kv (key, value) VALUES (?, ?) ON CONFLICT(key) DO NOTHING`,
|
||||
lockKey,
|
||||
`{"expires_at":0}`,
|
||||
).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
func TestAppLockServiceAcquire(t *testing.T) {
|
||||
t.Run("creates new lock when none exists", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
@@ -99,6 +116,66 @@ func TestAppLockServiceAcquire(t *testing.T) {
|
||||
require.Equal(t, service.hostID, stored.HostID)
|
||||
require.Greater(t, stored.ExpiresAt, time.Now().Unix())
|
||||
})
|
||||
|
||||
t.Run("force acquisition returns wait duration when stealing active lock", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
service := newTestAppLockService(t, db)
|
||||
|
||||
existing := lockValue{
|
||||
ProcessID: 99,
|
||||
HostID: "other-host",
|
||||
LockID: "other-lock-id",
|
||||
ExpiresAt: time.Now().Add(ttl).Unix(),
|
||||
}
|
||||
insertLock(t, db, existing)
|
||||
|
||||
waitUntil, err := service.Acquire(context.Background(), true)
|
||||
require.NoError(t, err)
|
||||
require.WithinDuration(t, time.Unix(existing.ExpiresAt, 0), waitUntil, time.Second)
|
||||
})
|
||||
|
||||
t.Run("force acquisition does not wait when lock id is unchanged", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
service := newTestAppLockService(t, db)
|
||||
|
||||
insertLock(t, db, lockValue{
|
||||
ProcessID: 99,
|
||||
HostID: "other-host",
|
||||
LockID: service.lockID,
|
||||
ExpiresAt: time.Now().Add(ttl).Unix(),
|
||||
})
|
||||
|
||||
waitUntil, err := service.Acquire(context.Background(), true)
|
||||
require.NoError(t, err)
|
||||
require.True(t, waitUntil.IsZero())
|
||||
})
|
||||
|
||||
t.Run("returns error when existing lock value is invalid JSON", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
service := newTestAppLockService(t, db)
|
||||
|
||||
raw := "this-is-not-json"
|
||||
err := db.Create(&model.KV{Key: lockKey, Value: &raw}).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = service.Acquire(context.Background(), false)
|
||||
require.ErrorContains(t, err, "decode existing lock value")
|
||||
})
|
||||
|
||||
t.Run("returns context deadline exceeded when database is locked", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
service := newTestAppLockService(t, db)
|
||||
|
||||
tx := lockDatabaseForWrite(t, db)
|
||||
defer tx.Rollback()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err := service.Acquire(ctx, false)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.ErrorContains(t, err, "begin lock transaction")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAppLockServiceRelease(t *testing.T) {
|
||||
@@ -134,6 +211,24 @@ func TestAppLockServiceRelease(t *testing.T) {
|
||||
stored := readLockValue(t, db)
|
||||
require.Equal(t, existing, stored)
|
||||
})
|
||||
|
||||
t.Run("returns context deadline exceeded when database is locked", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
service := newTestAppLockService(t, db)
|
||||
|
||||
_, err := service.Acquire(context.Background(), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
tx := lockDatabaseForWrite(t, db)
|
||||
defer tx.Rollback()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err = service.Release(ctx)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.ErrorContains(t, err, "release lock failed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAppLockServiceRenew(t *testing.T) {
|
||||
@@ -186,4 +281,21 @@ func TestAppLockServiceRenew(t *testing.T) {
|
||||
err = service.renew(context.Background())
|
||||
require.ErrorIs(t, err, ErrLockLost)
|
||||
})
|
||||
|
||||
t.Run("returns context deadline exceeded when database is locked", func(t *testing.T) {
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
service := newTestAppLockService(t, db)
|
||||
|
||||
_, err := service.Acquire(context.Background(), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
tx := lockDatabaseForWrite(t, db)
|
||||
defer tx.Rollback()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err = service.renew(ctx)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -258,7 +258,7 @@ func (s *TestService) SeedDatabase(baseURL string) error {
|
||||
Nonce: "nonce",
|
||||
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
|
||||
UserID: users[1].ID,
|
||||
ClientID: oidcClients[2].ID,
|
||||
ClientID: oidcClients[3].ID,
|
||||
},
|
||||
}
|
||||
for _, authCode := range authCodes {
|
||||
|
||||
@@ -150,7 +150,8 @@ func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Addr
|
||||
}
|
||||
|
||||
// Send the email
|
||||
if err := srv.sendEmailContent(client, toEmail, c); err != nil {
|
||||
err = srv.sendEmailContent(client, toEmail, c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("send email content: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ type LdapService struct {
|
||||
userService *UserService
|
||||
groupService *UserGroupService
|
||||
fileStorage storage.FileStorage
|
||||
clientFactory func() (ldapClient, error)
|
||||
}
|
||||
|
||||
type savePicture struct {
|
||||
@@ -43,8 +44,33 @@ type savePicture struct {
|
||||
picture string
|
||||
}
|
||||
|
||||
type ldapDesiredUser struct {
|
||||
ldapID string
|
||||
input dto.UserCreateDto
|
||||
picture string
|
||||
}
|
||||
|
||||
type ldapDesiredGroup struct {
|
||||
ldapID string
|
||||
input dto.UserGroupCreateDto
|
||||
memberUsernames []string
|
||||
}
|
||||
|
||||
type ldapDesiredState struct {
|
||||
users []ldapDesiredUser
|
||||
userIDs map[string]struct{}
|
||||
groups []ldapDesiredGroup
|
||||
groupIDs map[string]struct{}
|
||||
}
|
||||
|
||||
type ldapClient interface {
|
||||
Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error)
|
||||
Bind(username, password string) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppConfigService, userService *UserService, groupService *UserGroupService, fileStorage storage.FileStorage) *LdapService {
|
||||
return &LdapService{
|
||||
service := &LdapService{
|
||||
db: db,
|
||||
httpClient: httpClient,
|
||||
appConfigService: appConfigService,
|
||||
@@ -52,9 +78,12 @@ func NewLdapService(db *gorm.DB, httpClient *http.Client, appConfigService *AppC
|
||||
groupService: groupService,
|
||||
fileStorage: fileStorage,
|
||||
}
|
||||
|
||||
service.clientFactory = service.createClient
|
||||
return service
|
||||
}
|
||||
|
||||
func (s *LdapService) createClient() (*ldap.Conn, error) {
|
||||
func (s *LdapService) createClient() (ldapClient, error) {
|
||||
dbConfig := s.appConfigService.GetDbConfig()
|
||||
|
||||
if !dbConfig.LdapEnabled.IsTrue() {
|
||||
@@ -79,24 +108,33 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
|
||||
|
||||
func (s *LdapService) SyncAll(ctx context.Context) error {
|
||||
// Setup LDAP connection
|
||||
client, err := s.createClient()
|
||||
client, err := s.clientFactory()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create LDAP client: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Start a transaction
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
tx.Rollback()
|
||||
}()
|
||||
// First, we fetch all users and group from LDAP, which is our "desired state"
|
||||
desiredState, err := s.fetchDesiredState(ctx, client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch LDAP state: %w", err)
|
||||
}
|
||||
|
||||
savePictures, deleteFiles, err := s.SyncUsers(ctx, tx, client)
|
||||
// Start a transaction
|
||||
tx := s.db.WithContext(ctx).Begin()
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("failed to begin database transaction: %w", tx.Error)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Reconcile users
|
||||
savePictures, deleteFiles, err := s.reconcileUsers(ctx, tx, desiredState.users, desiredState.userIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync users: %w", err)
|
||||
}
|
||||
|
||||
err = s.SyncGroups(ctx, tx, client)
|
||||
// Reconcile groups
|
||||
err = s.reconcileGroups(ctx, tx, desiredState.groups, desiredState.groupIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync groups: %w", err)
|
||||
}
|
||||
@@ -129,10 +167,31 @@ func (s *LdapService) SyncAll(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.Conn) error {
|
||||
func (s *LdapService) fetchDesiredState(ctx context.Context, client ldapClient) (ldapDesiredState, error) {
|
||||
// Fetch users first so we can use their DNs when resolving group members
|
||||
users, userIDs, usernamesByDN, err := s.fetchUsersFromLDAP(ctx, client)
|
||||
if err != nil {
|
||||
return ldapDesiredState{}, err
|
||||
}
|
||||
|
||||
// Then fetch groups to complete the desired LDAP state snapshot
|
||||
groups, groupIDs, err := s.fetchGroupsFromLDAP(ctx, client, usernamesByDN)
|
||||
if err != nil {
|
||||
return ldapDesiredState{}, err
|
||||
}
|
||||
|
||||
return ldapDesiredState{
|
||||
users: users,
|
||||
userIDs: userIDs,
|
||||
groups: groups,
|
||||
groupIDs: groupIDs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *LdapService) fetchGroupsFromLDAP(ctx context.Context, client ldapClient, usernamesByDN map[string]string) (desiredGroups []ldapDesiredGroup, ldapGroupIDs map[string]struct{}, err error) {
|
||||
dbConfig := s.appConfigService.GetDbConfig()
|
||||
|
||||
// Query LDAP for all groups we want to manage
|
||||
searchAttrs := []string{
|
||||
dbConfig.LdapAttributeGroupName.Value,
|
||||
dbConfig.LdapAttributeGroupUniqueIdentifier.Value,
|
||||
@@ -149,90 +208,42 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
||||
)
|
||||
result, err := client.Search(searchReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query LDAP: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to query LDAP groups: %w", err)
|
||||
}
|
||||
|
||||
// Create a mapping for groups that exist
|
||||
ldapGroupIDs := make(map[string]struct{}, len(result.Entries))
|
||||
// Build the in-memory desired state for groups
|
||||
ldapGroupIDs = make(map[string]struct{}, len(result.Entries))
|
||||
desiredGroups = make([]ldapDesiredGroup, 0, len(result.Entries))
|
||||
|
||||
for _, value := range result.Entries {
|
||||
ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value))
|
||||
ldapID := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value))
|
||||
|
||||
// Skip groups without a valid LDAP ID
|
||||
if ldapId == "" {
|
||||
if ldapID == "" {
|
||||
slog.Warn("Skipping LDAP group without a valid unique identifier", slog.String("attribute", dbConfig.LdapAttributeGroupUniqueIdentifier.Value))
|
||||
continue
|
||||
}
|
||||
|
||||
ldapGroupIDs[ldapId] = struct{}{}
|
||||
|
||||
// Try to find the group in the database
|
||||
var databaseGroup model.UserGroup
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("ldap_id = ?", ldapId).
|
||||
First(&databaseGroup).
|
||||
Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// This could error with ErrRecordNotFound and we want to ignore that here
|
||||
return fmt.Errorf("failed to query for LDAP group ID '%s': %w", ldapId, err)
|
||||
}
|
||||
ldapGroupIDs[ldapID] = struct{}{}
|
||||
|
||||
// Get group members and add to the correct Group
|
||||
groupMembers := value.GetAttributeValues(dbConfig.LdapAttributeGroupMember.Value)
|
||||
membersUserId := make([]string, 0, len(groupMembers))
|
||||
memberUsernames := make([]string, 0, len(groupMembers))
|
||||
for _, member := range groupMembers {
|
||||
username := getDNProperty(dbConfig.LdapAttributeUserUsername.Value, member)
|
||||
|
||||
// If username extraction fails, try to query LDAP directly for the user
|
||||
username := s.resolveGroupMemberUsername(ctx, client, member, usernamesByDN)
|
||||
if username == "" {
|
||||
// Query LDAP to get the user by their DN
|
||||
userSearchReq := ldap.NewSearchRequest(
|
||||
member,
|
||||
ldap.ScopeBaseObject,
|
||||
0, 0, 0, false,
|
||||
"(objectClass=*)",
|
||||
[]string{dbConfig.LdapAttributeUserUsername.Value, dbConfig.LdapAttributeUserUniqueIdentifier.Value},
|
||||
[]ldap.Control{},
|
||||
)
|
||||
|
||||
userResult, err := client.Search(userSearchReq)
|
||||
if err != nil || len(userResult.Entries) == 0 {
|
||||
slog.WarnContext(ctx, "Could not resolve group member DN", slog.String("member", member), slog.Any("error", err))
|
||||
continue
|
||||
}
|
||||
|
||||
username = userResult.Entries[0].GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value)
|
||||
if username == "" {
|
||||
slog.WarnContext(ctx, "Could not extract username from group member DN", slog.String("member", member))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
username = norm.NFC.String(username)
|
||||
|
||||
var databaseUser model.User
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("username = ? AND ldap_id IS NOT NULL", username).
|
||||
First(&databaseUser).
|
||||
Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// The user collides with a non-LDAP user, so we skip it
|
||||
continue
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to query for existing user '%s': %w", username, err)
|
||||
}
|
||||
|
||||
membersUserId = append(membersUserId, databaseUser.ID)
|
||||
memberUsernames = append(memberUsernames, username)
|
||||
}
|
||||
|
||||
syncGroup := dto.UserGroupCreateDto{
|
||||
Name: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
|
||||
FriendlyName: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
|
||||
LdapID: ldapId,
|
||||
LdapID: ldapID,
|
||||
}
|
||||
dto.Normalize(syncGroup)
|
||||
dto.Normalize(&syncGroup)
|
||||
|
||||
err = syncGroup.Validate()
|
||||
if err != nil {
|
||||
@@ -240,64 +251,20 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
||||
continue
|
||||
}
|
||||
|
||||
if databaseGroup.ID == "" {
|
||||
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create group '%s': %w", syncGroup.Name, err)
|
||||
}
|
||||
|
||||
_, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, membersUserId, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err)
|
||||
}
|
||||
} else {
|
||||
_, err = s.groupService.updateInternal(ctx, databaseGroup.ID, syncGroup, true, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update group '%s': %w", syncGroup.Name, err)
|
||||
}
|
||||
|
||||
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, membersUserId, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync users for group '%s': %w", syncGroup.Name, err)
|
||||
}
|
||||
}
|
||||
desiredGroups = append(desiredGroups, ldapDesiredGroup{
|
||||
ldapID: ldapID,
|
||||
input: syncGroup,
|
||||
memberUsernames: memberUsernames,
|
||||
})
|
||||
}
|
||||
|
||||
// Get all LDAP groups from the database
|
||||
var ldapGroupsInDb []model.UserGroup
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").
|
||||
Select("ldap_id").
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch groups from database: %w", err)
|
||||
}
|
||||
|
||||
// Delete groups that no longer exist in LDAP
|
||||
for _, group := range ldapGroupsInDb {
|
||||
if _, exists := ldapGroupIDs[*group.LdapID]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Delete(&model.UserGroup{}, "ldap_id = ?", group.LdapID).
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete group '%s': %w", group.Name, err)
|
||||
}
|
||||
|
||||
slog.Info("Deleted group", slog.String("group", group.Name))
|
||||
}
|
||||
|
||||
return nil
|
||||
return desiredGroups, ldapGroupIDs, nil
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.Conn) (savePictures []savePicture, deleteFiles []string, err error) {
|
||||
func (s *LdapService) fetchUsersFromLDAP(ctx context.Context, client ldapClient) (desiredUsers []ldapDesiredUser, ldapUserIDs map[string]struct{}, usernamesByDN map[string]string, err error) {
|
||||
dbConfig := s.appConfigService.GetDbConfig()
|
||||
|
||||
// Query LDAP for all users we want to manage
|
||||
searchAttrs := []string{
|
||||
"memberOf",
|
||||
"sn",
|
||||
@@ -323,50 +290,29 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
|
||||
result, err := client.Search(searchReq)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to query LDAP: %w", err)
|
||||
return nil, nil, nil, fmt.Errorf("failed to query LDAP users: %w", err)
|
||||
}
|
||||
|
||||
// Create a mapping for users that exist
|
||||
ldapUserIDs := make(map[string]struct{}, len(result.Entries))
|
||||
savePictures = make([]savePicture, 0, len(result.Entries))
|
||||
// Build the in-memory desired state for users and a DN lookup for group membership resolution
|
||||
ldapUserIDs = make(map[string]struct{}, len(result.Entries))
|
||||
usernamesByDN = make(map[string]string, len(result.Entries))
|
||||
desiredUsers = make([]ldapDesiredUser, 0, len(result.Entries))
|
||||
|
||||
for _, value := range result.Entries {
|
||||
ldapId := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value))
|
||||
username := norm.NFC.String(value.GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value))
|
||||
if normalizedDN := normalizeLDAPDN(value.DN); normalizedDN != "" && username != "" {
|
||||
usernamesByDN[normalizedDN] = username
|
||||
}
|
||||
|
||||
ldapID := convertLdapIdToString(value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value))
|
||||
|
||||
// Skip users without a valid LDAP ID
|
||||
if ldapId == "" {
|
||||
if ldapID == "" {
|
||||
slog.Warn("Skipping LDAP user without a valid unique identifier", slog.String("attribute", dbConfig.LdapAttributeUserUniqueIdentifier.Value))
|
||||
continue
|
||||
}
|
||||
|
||||
ldapUserIDs[ldapId] = struct{}{}
|
||||
|
||||
// Get the user from the database
|
||||
var databaseUser model.User
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Where("ldap_id = ?", ldapId).
|
||||
First(&databaseUser).
|
||||
Error
|
||||
|
||||
// If a user is found (even if disabled), enable them since they're now back in LDAP
|
||||
if databaseUser.ID != "" && databaseUser.Disabled {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Model(&model.User{}).
|
||||
Where("id = ?", databaseUser.ID).
|
||||
Update("disabled", false).
|
||||
Error
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// This could error with ErrRecordNotFound and we want to ignore that here
|
||||
return nil, nil, fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err)
|
||||
}
|
||||
ldapUserIDs[ldapID] = struct{}{}
|
||||
|
||||
// Check if user is admin by checking if they are in the admin group
|
||||
isAdmin := false
|
||||
@@ -385,14 +331,14 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
LastName: value.GetAttributeValue(dbConfig.LdapAttributeUserLastName.Value),
|
||||
DisplayName: value.GetAttributeValue(dbConfig.LdapAttributeUserDisplayName.Value),
|
||||
IsAdmin: isAdmin,
|
||||
LdapID: ldapId,
|
||||
LdapID: ldapID,
|
||||
}
|
||||
|
||||
if newUser.DisplayName == "" {
|
||||
newUser.DisplayName = strings.TrimSpace(newUser.FirstName + " " + newUser.LastName)
|
||||
}
|
||||
|
||||
dto.Normalize(newUser)
|
||||
dto.Normalize(&newUser)
|
||||
|
||||
err = newUser.Validate()
|
||||
if err != nil {
|
||||
@@ -400,53 +346,201 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
continue
|
||||
}
|
||||
|
||||
userID := databaseUser.ID
|
||||
if databaseUser.ID == "" {
|
||||
createdUser, err := s.userService.createUserInternal(ctx, newUser, true, tx)
|
||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||
slog.Warn("Skipping creating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
|
||||
desiredUsers = append(desiredUsers, ldapDesiredUser{
|
||||
ldapID: ldapID,
|
||||
input: newUser,
|
||||
picture: value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value),
|
||||
})
|
||||
}
|
||||
|
||||
return desiredUsers, ldapUserIDs, usernamesByDN, nil
|
||||
}
|
||||
|
||||
func (s *LdapService) resolveGroupMemberUsername(ctx context.Context, client ldapClient, member string, usernamesByDN map[string]string) string {
|
||||
dbConfig := s.appConfigService.GetDbConfig()
|
||||
|
||||
// First try the DN cache we built while loading users
|
||||
username, exists := usernamesByDN[normalizeLDAPDN(member)]
|
||||
if exists && username != "" {
|
||||
return username
|
||||
}
|
||||
|
||||
// Then try to extract the username directly from the DN
|
||||
username = getDNProperty(dbConfig.LdapAttributeUserUsername.Value, member)
|
||||
if username != "" {
|
||||
return norm.NFC.String(username)
|
||||
}
|
||||
|
||||
// As a fallback, query LDAP for the referenced entry
|
||||
userSearchReq := ldap.NewSearchRequest(
|
||||
member,
|
||||
ldap.ScopeBaseObject,
|
||||
0, 0, 0, false,
|
||||
"(objectClass=*)",
|
||||
[]string{dbConfig.LdapAttributeUserUsername.Value},
|
||||
[]ldap.Control{},
|
||||
)
|
||||
|
||||
userResult, err := client.Search(userSearchReq)
|
||||
if err != nil || len(userResult.Entries) == 0 {
|
||||
slog.WarnContext(ctx, "Could not resolve group member DN", slog.String("member", member), slog.Any("error", err))
|
||||
return ""
|
||||
}
|
||||
|
||||
username = userResult.Entries[0].GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value)
|
||||
if username == "" {
|
||||
slog.WarnContext(ctx, "Could not extract username from group member DN", slog.String("member", member))
|
||||
return ""
|
||||
}
|
||||
|
||||
return norm.NFC.String(username)
|
||||
}
|
||||
|
||||
func (s *LdapService) reconcileGroups(ctx context.Context, tx *gorm.DB, desiredGroups []ldapDesiredGroup, ldapGroupIDs map[string]struct{}) error {
|
||||
// Load the current LDAP-managed state from the database
|
||||
ldapGroupsInDB, ldapGroupsByID, err := s.loadLDAPGroupsInDB(ctx, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch groups from database: %w", err)
|
||||
}
|
||||
|
||||
_, _, ldapUsersByUsername, err := s.loadLDAPUsersInDB(ctx, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch users from database: %w", err)
|
||||
}
|
||||
|
||||
// Apply creates and updates to match the desired LDAP group state
|
||||
for _, desiredGroup := range desiredGroups {
|
||||
memberUserIDs := make([]string, 0, len(desiredGroup.memberUsernames))
|
||||
for _, username := range desiredGroup.memberUsernames {
|
||||
databaseUser, exists := ldapUsersByUsername[username]
|
||||
if !exists {
|
||||
// The user collides with a non-LDAP user or was skipped during user sync, so we ignore it
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, nil, fmt.Errorf("error creating user '%s': %w", newUser.Username, err)
|
||||
}
|
||||
userID = createdUser.ID
|
||||
} else {
|
||||
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, newUser, false, true, tx)
|
||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||
slog.Warn("Skipping updating LDAP user", slog.String("username", newUser.Username), slog.Any("error", err))
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, nil, fmt.Errorf("error updating user '%s': %w", newUser.Username, err)
|
||||
}
|
||||
|
||||
memberUserIDs = append(memberUserIDs, databaseUser.ID)
|
||||
}
|
||||
|
||||
// Save profile picture
|
||||
pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value)
|
||||
if pictureString != "" {
|
||||
// Storage operations must be executed outside of a transaction
|
||||
savePictures = append(savePictures, savePicture{
|
||||
userID: databaseUser.ID,
|
||||
username: userID,
|
||||
picture: pictureString,
|
||||
})
|
||||
databaseGroup := ldapGroupsByID[desiredGroup.ldapID]
|
||||
if databaseGroup.ID == "" {
|
||||
newGroup, err := s.groupService.createInternal(ctx, desiredGroup.input, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create group '%s': %w", desiredGroup.input.Name, err)
|
||||
}
|
||||
ldapGroupsByID[desiredGroup.ldapID] = newGroup
|
||||
|
||||
_, err = s.groupService.updateUsersInternal(ctx, newGroup.ID, memberUserIDs, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync users for group '%s': %w", desiredGroup.input.Name, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = s.groupService.updateInternal(ctx, databaseGroup.ID, desiredGroup.input, true, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update group '%s': %w", desiredGroup.input.Name, err)
|
||||
}
|
||||
|
||||
_, err = s.groupService.updateUsersInternal(ctx, databaseGroup.ID, memberUserIDs, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sync users for group '%s': %w", desiredGroup.input.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get all LDAP users from the database
|
||||
var ldapUsersInDb []model.User
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Find(&ldapUsersInDb, "ldap_id IS NOT NULL").
|
||||
Select("id, username, ldap_id, disabled").
|
||||
Error
|
||||
// Delete groups that are no longer present in LDAP
|
||||
for _, group := range ldapGroupsInDB {
|
||||
if group.LdapID == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, exists := ldapGroupIDs[*group.LdapID]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Delete(&model.UserGroup{}, "ldap_id = ?", *group.LdapID).
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete group '%s': %w", group.Name, err)
|
||||
}
|
||||
|
||||
slog.Info("Deleted group", slog.String("group", group.Name))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:gocognit
|
||||
func (s *LdapService) reconcileUsers(ctx context.Context, tx *gorm.DB, desiredUsers []ldapDesiredUser, ldapUserIDs map[string]struct{}) (savePictures []savePicture, deleteFiles []string, err error) {
|
||||
dbConfig := s.appConfigService.GetDbConfig()
|
||||
|
||||
// Load the current LDAP-managed state from the database
|
||||
ldapUsersInDB, ldapUsersByID, _, err := s.loadLDAPUsersInDB(ctx, tx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to fetch users from database: %w", err)
|
||||
}
|
||||
|
||||
// Mark users as disabled or delete users that no longer exist in LDAP
|
||||
deleteFiles = make([]string, 0, len(ldapUserIDs))
|
||||
for _, user := range ldapUsersInDb {
|
||||
// Skip if the user ID exists in the fetched LDAP results
|
||||
// Apply creates and updates to match the desired LDAP user state
|
||||
savePictures = make([]savePicture, 0, len(desiredUsers))
|
||||
|
||||
for _, desiredUser := range desiredUsers {
|
||||
databaseUser := ldapUsersByID[desiredUser.ldapID]
|
||||
|
||||
// If a user is found (even if disabled), enable them since they're now back in LDAP.
|
||||
if databaseUser.ID != "" && databaseUser.Disabled {
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Model(&model.User{}).
|
||||
Where("id = ?", databaseUser.ID).
|
||||
Update("disabled", false).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to enable user %s: %w", databaseUser.Username, err)
|
||||
}
|
||||
|
||||
databaseUser.Disabled = false
|
||||
ldapUsersByID[desiredUser.ldapID] = databaseUser
|
||||
}
|
||||
|
||||
userID := databaseUser.ID
|
||||
if databaseUser.ID == "" {
|
||||
createdUser, err := s.userService.createUserInternal(ctx, desiredUser.input, true, tx)
|
||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||
slog.Warn("Skipping creating LDAP user", slog.String("username", desiredUser.input.Username), slog.Any("error", err))
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, nil, fmt.Errorf("error creating user '%s': %w", desiredUser.input.Username, err)
|
||||
}
|
||||
|
||||
userID = createdUser.ID
|
||||
ldapUsersByID[desiredUser.ldapID] = createdUser
|
||||
} else {
|
||||
_, err = s.userService.updateUserInternal(ctx, databaseUser.ID, desiredUser.input, false, true, tx)
|
||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||
slog.Warn("Skipping updating LDAP user", slog.String("username", desiredUser.input.Username), slog.Any("error", err))
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, nil, fmt.Errorf("error updating user '%s': %w", desiredUser.input.Username, err)
|
||||
}
|
||||
}
|
||||
|
||||
if desiredUser.picture != "" {
|
||||
savePictures = append(savePictures, savePicture{
|
||||
userID: userID,
|
||||
username: desiredUser.input.Username,
|
||||
picture: desiredUser.picture,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Disable or delete users that are no longer present in LDAP
|
||||
deleteFiles = make([]string, 0, len(ldapUsersInDB))
|
||||
for _, user := range ldapUsersInDB {
|
||||
if user.LdapID == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, exists := ldapUserIDs[*user.LdapID]; exists {
|
||||
continue
|
||||
}
|
||||
@@ -458,29 +552,73 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
}
|
||||
|
||||
slog.Info("Disabled user", slog.String("username", user.Username))
|
||||
} else {
|
||||
err = s.userService.deleteUserInternal(ctx, tx, user.ID, true)
|
||||
if err != nil {
|
||||
target := &common.LdapUserUpdateError{}
|
||||
if errors.As(err, &target) {
|
||||
return nil, nil, fmt.Errorf("failed to delete user %s: LDAP user must be disabled before deletion", user.Username)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("failed to delete user %s: %w", user.Username, err)
|
||||
}
|
||||
|
||||
slog.Info("Deleted user", slog.String("username", user.Username))
|
||||
|
||||
// Storage operations must be executed outside of a transaction
|
||||
deleteFiles = append(deleteFiles, path.Join("profile-pictures", user.ID+".png"))
|
||||
continue
|
||||
}
|
||||
|
||||
err = s.userService.deleteUserInternal(ctx, tx, user.ID, true)
|
||||
if err != nil {
|
||||
target := &common.LdapUserUpdateError{}
|
||||
if errors.As(err, &target) {
|
||||
return nil, nil, fmt.Errorf("failed to delete user %s: LDAP user must be disabled before deletion", user.Username)
|
||||
}
|
||||
return nil, nil, fmt.Errorf("failed to delete user %s: %w", user.Username, err)
|
||||
}
|
||||
|
||||
slog.Info("Deleted user", slog.String("username", user.Username))
|
||||
deleteFiles = append(deleteFiles, path.Join("profile-pictures", user.ID+".png"))
|
||||
}
|
||||
|
||||
return savePictures, deleteFiles, nil
|
||||
}
|
||||
|
||||
func (s *LdapService) loadLDAPUsersInDB(ctx context.Context, tx *gorm.DB) (users []model.User, byLdapID map[string]model.User, byUsername map[string]model.User, err error) {
|
||||
// Load all LDAP-managed users and index them by LDAP ID and by username
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Select("id, username, ldap_id, disabled").
|
||||
Where("ldap_id IS NOT NULL").
|
||||
Find(&users).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
byLdapID = make(map[string]model.User, len(users))
|
||||
byUsername = make(map[string]model.User, len(users))
|
||||
for _, user := range users {
|
||||
byLdapID[*user.LdapID] = user
|
||||
byUsername[user.Username] = user
|
||||
}
|
||||
|
||||
return users, byLdapID, byUsername, nil
|
||||
}
|
||||
|
||||
func (s *LdapService) loadLDAPGroupsInDB(ctx context.Context, tx *gorm.DB) ([]model.UserGroup, map[string]model.UserGroup, error) {
|
||||
var groups []model.UserGroup
|
||||
|
||||
// Load all LDAP-managed groups and index them by LDAP ID
|
||||
err := tx.
|
||||
WithContext(ctx).
|
||||
Select("id, name, ldap_id").
|
||||
Where("ldap_id IS NOT NULL").
|
||||
Find(&groups).
|
||||
Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
groupsByID := make(map[string]model.UserGroup, len(groups))
|
||||
for _, group := range groups {
|
||||
groupsByID[*group.LdapID] = group
|
||||
}
|
||||
|
||||
return groups, groupsByID, nil
|
||||
}
|
||||
|
||||
func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId string, pictureString string) error {
|
||||
var reader io.ReadSeeker
|
||||
|
||||
// Accept either a URL, a base64-encoded payload, or raw binary data
|
||||
_, err := url.ParseRequestURI(pictureString)
|
||||
if err == nil {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second)
|
||||
@@ -522,6 +660,31 @@ func (s *LdapService) saveProfilePicture(parentCtx context.Context, userId strin
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizeLDAPDN returns a canonical lowercase form of a DN for use as a map key.
|
||||
// Different LDAP servers may format the same DN with varying attribute type casing (e.g. "CN=" vs "cn=") or extra whitespace (e.g. "dc=example, dc=com").
|
||||
// Without normalization, cache lookups in usernamesByDN would miss when a member attribute value uses a different format than the DN returned in the search entry
|
||||
//
|
||||
// ldap.ParseDN is used instead of simple lowercasing because it correctly handles multi-valued RDNs (joined with "+") and strips inter-component whitespace.
|
||||
// If parsing fails for any reason, we fall back to a simple lowercase+trim.
|
||||
func normalizeLDAPDN(dn string) string {
|
||||
parsed, err := ldap.ParseDN(dn)
|
||||
if err != nil {
|
||||
return strings.ToLower(strings.TrimSpace(dn))
|
||||
}
|
||||
|
||||
// Reconstruct the DN in a canonical form: lowercase type=lowercase value, with RDN components separated by "," and multi-value attributes by "+"
|
||||
parts := make([]string, 0, len(parsed.RDNs))
|
||||
for _, rdn := range parsed.RDNs {
|
||||
attrs := make([]string, 0, len(rdn.Attributes))
|
||||
for _, attr := range rdn.Attributes {
|
||||
attrs = append(attrs, strings.ToLower(attr.Type)+"="+strings.ToLower(attr.Value))
|
||||
}
|
||||
parts = append(parts, strings.Join(attrs, "+"))
|
||||
}
|
||||
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
// getDNProperty returns the value of a property from a LDAP identifier
|
||||
// See: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/ldap/distinguished-names
|
||||
func getDNProperty(property string, str string) string {
|
||||
|
||||
@@ -1,9 +1,286 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||
"github.com/pocket-id/pocket-id/backend/internal/storage"
|
||||
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||
)
|
||||
|
||||
type fakeLDAPClient struct {
|
||||
searchFn func(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error)
|
||||
}
|
||||
|
||||
func (c *fakeLDAPClient) Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) {
|
||||
if c.searchFn == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return c.searchFn(searchRequest)
|
||||
}
|
||||
|
||||
func (c *fakeLDAPClient) Bind(_, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeLDAPClient) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestLdapServiceSyncAllReconcilesUsersAndGroups(t *testing.T) {
|
||||
service, db := newTestLdapService(t, newFakeLDAPClient(
|
||||
ldapSearchResult(
|
||||
ldapEntry("uid=alice,ou=people,dc=example,dc=com", map[string][]string{
|
||||
"entryUUID": {"u-alice"},
|
||||
"uid": {"alice"},
|
||||
"mail": {"alice@example.com"},
|
||||
"givenName": {"Alice"},
|
||||
"sn": {"Jones"},
|
||||
"displayName": {""},
|
||||
"memberOf": {"cn=admins,ou=groups,dc=example,dc=com"},
|
||||
}),
|
||||
ldapEntry("uid=bob,ou=people,dc=example,dc=com", map[string][]string{
|
||||
"entryUUID": {"u-bob"},
|
||||
"uid": {"bob"},
|
||||
"mail": {"bob@example.com"},
|
||||
"givenName": {"Bob"},
|
||||
"sn": {"Brown"},
|
||||
"displayName": {""},
|
||||
}),
|
||||
),
|
||||
ldapSearchResult(
|
||||
ldapEntry("cn=team,ou=groups,dc=example,dc=com", map[string][]string{
|
||||
"entryUUID": {"g-team"},
|
||||
"cn": {"team"},
|
||||
"member": {
|
||||
"UID=Alice, OU=People, DC=example, DC=com",
|
||||
"uid=bob, ou=people, dc=example, dc=com",
|
||||
},
|
||||
}),
|
||||
),
|
||||
))
|
||||
|
||||
aliceLdapID := "u-alice"
|
||||
missingLdapID := "u-missing"
|
||||
teamLdapID := "g-team"
|
||||
oldGroupLdapID := "g-old"
|
||||
|
||||
require.NoError(t, db.Create(&model.User{
|
||||
Username: "alice-old",
|
||||
Email: new("alice-old@example.com"),
|
||||
EmailVerified: true,
|
||||
FirstName: "Old",
|
||||
LastName: "Name",
|
||||
DisplayName: "Old Name",
|
||||
LdapID: &aliceLdapID,
|
||||
Disabled: true,
|
||||
}).Error)
|
||||
|
||||
require.NoError(t, db.Create(&model.User{
|
||||
Username: "missing",
|
||||
Email: new("missing@example.com"),
|
||||
EmailVerified: true,
|
||||
FirstName: "Missing",
|
||||
LastName: "User",
|
||||
DisplayName: "Missing User",
|
||||
LdapID: &missingLdapID,
|
||||
}).Error)
|
||||
|
||||
require.NoError(t, db.Create(&model.UserGroup{
|
||||
Name: "team-old",
|
||||
FriendlyName: "team-old",
|
||||
LdapID: &teamLdapID,
|
||||
}).Error)
|
||||
|
||||
require.NoError(t, db.Create(&model.UserGroup{
|
||||
Name: "old-group",
|
||||
FriendlyName: "old-group",
|
||||
LdapID: &oldGroupLdapID,
|
||||
}).Error)
|
||||
|
||||
require.NoError(t, service.SyncAll(t.Context()))
|
||||
|
||||
var alice model.User
|
||||
require.NoError(t, db.First(&alice, "ldap_id = ?", aliceLdapID).Error)
|
||||
assert.Equal(t, "alice", alice.Username)
|
||||
assert.Equal(t, new("alice@example.com"), alice.Email)
|
||||
assert.Equal(t, "Alice", alice.FirstName)
|
||||
assert.Equal(t, "Jones", alice.LastName)
|
||||
assert.Equal(t, "Alice Jones", alice.DisplayName)
|
||||
assert.True(t, alice.IsAdmin)
|
||||
assert.False(t, alice.Disabled)
|
||||
|
||||
var bob model.User
|
||||
require.NoError(t, db.First(&bob, "ldap_id = ?", "u-bob").Error)
|
||||
assert.Equal(t, "bob", bob.Username)
|
||||
assert.Equal(t, "Bob Brown", bob.DisplayName)
|
||||
|
||||
var missing model.User
|
||||
require.NoError(t, db.First(&missing, "ldap_id = ?", missingLdapID).Error)
|
||||
assert.True(t, missing.Disabled)
|
||||
|
||||
var oldGroupCount int64
|
||||
require.NoError(t, db.Model(&model.UserGroup{}).Where("ldap_id = ?", oldGroupLdapID).Count(&oldGroupCount).Error)
|
||||
assert.Zero(t, oldGroupCount)
|
||||
|
||||
var team model.UserGroup
|
||||
require.NoError(t, db.Preload("Users").First(&team, "ldap_id = ?", teamLdapID).Error)
|
||||
assert.Equal(t, "team", team.Name)
|
||||
assert.Equal(t, "team", team.FriendlyName)
|
||||
assert.ElementsMatch(t, []string{"alice", "bob"}, usernames(team.Users))
|
||||
}
|
||||
|
||||
func TestLdapServiceSyncAllHandlesDuplicateLDAPIDsInSingleRun(t *testing.T) {
|
||||
service, db := newTestLdapService(t, newFakeLDAPClient(
|
||||
ldapSearchResult(
|
||||
ldapEntry("uid=alice,ou=people,dc=example,dc=com", map[string][]string{
|
||||
"entryUUID": {"u-dup"},
|
||||
"uid": {"alice"},
|
||||
"mail": {"alice@example.com"},
|
||||
"givenName": {"Alice"},
|
||||
"sn": {"Doe"},
|
||||
"displayName": {"Alice Doe"},
|
||||
}),
|
||||
ldapEntry("uid=alice,ou=people,dc=example,dc=com", map[string][]string{
|
||||
"entryUUID": {"u-dup"},
|
||||
"uid": {"alice"},
|
||||
"mail": {"alice@example.com"},
|
||||
"givenName": {"Alicia"},
|
||||
"sn": {"Doe"},
|
||||
"displayName": {"Alicia Doe"},
|
||||
}),
|
||||
),
|
||||
ldapSearchResult(
|
||||
ldapEntry("cn=team,ou=groups,dc=example,dc=com", map[string][]string{
|
||||
"entryUUID": {"g-dup"},
|
||||
"cn": {"team"},
|
||||
"member": {"uid=alice,ou=people,dc=example,dc=com"},
|
||||
}),
|
||||
ldapEntry("cn=team,ou=groups,dc=example,dc=com", map[string][]string{
|
||||
"entryUUID": {"g-dup"},
|
||||
"cn": {"team-renamed"},
|
||||
"member": {"uid=alice,ou=people,dc=example,dc=com"},
|
||||
}),
|
||||
),
|
||||
))
|
||||
|
||||
require.NoError(t, service.SyncAll(t.Context()))
|
||||
|
||||
var users []model.User
|
||||
require.NoError(t, db.Find(&users, "ldap_id = ?", "u-dup").Error)
|
||||
require.Len(t, users, 1)
|
||||
assert.Equal(t, "alice", users[0].Username)
|
||||
assert.Equal(t, "Alicia", users[0].FirstName)
|
||||
assert.Equal(t, "Alicia Doe", users[0].DisplayName)
|
||||
|
||||
var groups []model.UserGroup
|
||||
require.NoError(t, db.Preload("Users").Find(&groups, "ldap_id = ?", "g-dup").Error)
|
||||
require.Len(t, groups, 1)
|
||||
assert.Equal(t, "team-renamed", groups[0].Name)
|
||||
assert.Equal(t, "team-renamed", groups[0].FriendlyName)
|
||||
assert.ElementsMatch(t, []string{"alice"}, usernames(groups[0].Users))
|
||||
}
|
||||
|
||||
func newTestLdapService(t *testing.T, client ldapClient) (*LdapService, *gorm.DB) {
|
||||
t.Helper()
|
||||
|
||||
db := testutils.NewDatabaseForTest(t)
|
||||
|
||||
fileStorage, err := storage.NewDatabaseStorage(db)
|
||||
require.NoError(t, err)
|
||||
|
||||
appConfig := NewTestAppConfigService(&model.AppConfig{
|
||||
RequireUserEmail: model.AppConfigVariable{Value: "false"},
|
||||
LdapEnabled: model.AppConfigVariable{Value: "true"},
|
||||
LdapBase: model.AppConfigVariable{Value: "dc=example,dc=com"},
|
||||
LdapUserSearchFilter: model.AppConfigVariable{Value: "(objectClass=person)"},
|
||||
LdapUserGroupSearchFilter: model.AppConfigVariable{Value: "(objectClass=groupOfNames)"},
|
||||
LdapAttributeUserUniqueIdentifier: model.AppConfigVariable{Value: "entryUUID"},
|
||||
LdapAttributeUserUsername: model.AppConfigVariable{Value: "uid"},
|
||||
LdapAttributeUserEmail: model.AppConfigVariable{Value: "mail"},
|
||||
LdapAttributeUserFirstName: model.AppConfigVariable{Value: "givenName"},
|
||||
LdapAttributeUserLastName: model.AppConfigVariable{Value: "sn"},
|
||||
LdapAttributeUserDisplayName: model.AppConfigVariable{Value: "displayName"},
|
||||
LdapAttributeUserProfilePicture: model.AppConfigVariable{Value: "jpegPhoto"},
|
||||
LdapAttributeGroupMember: model.AppConfigVariable{Value: "member"},
|
||||
LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{Value: "entryUUID"},
|
||||
LdapAttributeGroupName: model.AppConfigVariable{Value: "cn"},
|
||||
LdapAdminGroupName: model.AppConfigVariable{Value: "admins"},
|
||||
LdapSoftDeleteUsers: model.AppConfigVariable{Value: "true"},
|
||||
})
|
||||
|
||||
groupService := NewUserGroupService(db, appConfig, nil)
|
||||
userService := NewUserService(
|
||||
db,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
appConfig,
|
||||
NewCustomClaimService(db),
|
||||
NewAppImagesService(map[string]string{}, fileStorage),
|
||||
nil,
|
||||
fileStorage,
|
||||
)
|
||||
|
||||
service := NewLdapService(db, &http.Client{}, appConfig, userService, groupService, fileStorage)
|
||||
service.clientFactory = func() (ldapClient, error) {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
return service, db
|
||||
}
|
||||
|
||||
func newFakeLDAPClient(userResult, groupResult *ldap.SearchResult) ldapClient {
|
||||
return &fakeLDAPClient{
|
||||
searchFn: func(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) {
|
||||
switch searchRequest.Filter {
|
||||
case "(objectClass=person)":
|
||||
return userResult, nil
|
||||
case "(objectClass=groupOfNames)":
|
||||
return groupResult, nil
|
||||
default:
|
||||
return &ldap.SearchResult{}, nil
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func ldapSearchResult(entries ...*ldap.Entry) *ldap.SearchResult {
|
||||
return &ldap.SearchResult{Entries: entries}
|
||||
}
|
||||
|
||||
func ldapEntry(dn string, attrs map[string][]string) *ldap.Entry {
|
||||
entry := &ldap.Entry{
|
||||
DN: dn,
|
||||
Attributes: make([]*ldap.EntryAttribute, 0, len(attrs)),
|
||||
}
|
||||
|
||||
for name, values := range attrs {
|
||||
entry.Attributes = append(entry.Attributes, &ldap.EntryAttribute{
|
||||
Name: name,
|
||||
Values: values,
|
||||
})
|
||||
}
|
||||
|
||||
return entry
|
||||
}
|
||||
|
||||
func usernames(users []model.User) []string {
|
||||
result := make([]string, 0, len(users))
|
||||
for _, user := range users {
|
||||
result = append(result, user.Username)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func TestGetDNProperty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -64,10 +341,58 @@ func TestGetDNProperty(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getDNProperty(tt.property, tt.dn)
|
||||
if result != tt.expectedResult {
|
||||
t.Errorf("getDNProperty(%q, %q) = %q, want %q",
|
||||
tt.property, tt.dn, result, tt.expectedResult)
|
||||
}
|
||||
assert.Equalf(t, tt.expectedResult, result, "getDNProperty(%q, %q)", tt.property, tt.dn)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeLDAPDN(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "already normalized",
|
||||
input: "cn=alice,dc=example,dc=com",
|
||||
expected: "cn=alice,dc=example,dc=com",
|
||||
},
|
||||
{
|
||||
name: "uppercase attribute types",
|
||||
input: "CN=Alice,DC=example,DC=com",
|
||||
expected: "cn=alice,dc=example,dc=com",
|
||||
},
|
||||
{
|
||||
name: "spaces after commas",
|
||||
input: "cn=alice, dc=example, dc=com",
|
||||
expected: "cn=alice,dc=example,dc=com",
|
||||
},
|
||||
{
|
||||
name: "uppercase types and spaces",
|
||||
input: "CN=Alice, DC=example, DC=com",
|
||||
expected: "cn=alice,dc=example,dc=com",
|
||||
},
|
||||
{
|
||||
name: "multi-valued RDN",
|
||||
input: "cn=alice+uid=a123,dc=example,dc=com",
|
||||
expected: "cn=alice+uid=a123,dc=example,dc=com",
|
||||
},
|
||||
{
|
||||
name: "invalid DN falls back to lowercase+trim",
|
||||
input: " NOT A VALID DN ",
|
||||
expected: "not a valid dn",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := normalizeLDAPDN(tt.input)
|
||||
assert.Equalf(t, tt.expected, result, "normalizeLDAPDN(%q)", tt.input)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -98,9 +423,7 @@ func TestConvertLdapIdToString(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := convertLdapIdToString(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("Expected %q, got %q", tt.expected, got)
|
||||
}
|
||||
assert.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -404,7 +404,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
|
||||
}
|
||||
}
|
||||
|
||||
if authorizationCodeMetaData.ClientID != input.ClientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
|
||||
if authorizationCodeMetaData.ClientID != input.ClientID || authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
|
||||
return CreatedTokens{}, &common.OidcInvalidAuthorizationCodeError{}
|
||||
}
|
||||
|
||||
|
||||
25
backend/internal/service/scheduler.go
Normal file
25
backend/internal/service/scheduler.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
backoff "github.com/cenkalti/backoff/v5"
|
||||
"github.com/go-co-op/gocron/v2"
|
||||
)
|
||||
|
||||
// RegisterJobOpts holds optional configuration for registering a scheduled job.
|
||||
type RegisterJobOpts struct {
|
||||
// RunImmediately runs the job immediately after registration.
|
||||
RunImmediately bool
|
||||
// ExtraOptions are additional gocron job options.
|
||||
ExtraOptions []gocron.JobOption
|
||||
// BackOff is an optional backoff strategy. If non-nil, the job will be wrapped
|
||||
// with automatic retry logic using the provided backoff on transient failures.
|
||||
BackOff backoff.BackOff
|
||||
}
|
||||
|
||||
// Scheduler is an interface for registering and managing background jobs.
|
||||
type Scheduler interface {
|
||||
RegisterJob(ctx context.Context, name string, def gocron.JobDefinition, job func(ctx context.Context) error, opts RegisterJobOpts) error
|
||||
RemoveJob(name string) error
|
||||
}
|
||||
@@ -34,11 +34,6 @@ const scimErrorBodyLimit = 4096
|
||||
|
||||
type scimSyncAction int
|
||||
|
||||
type Scheduler interface {
|
||||
RegisterJob(ctx context.Context, name string, def gocron.JobDefinition, job func(ctx context.Context) error, runImmediately bool, extraOptions ...gocron.JobOption) error
|
||||
RemoveJob(name string) error
|
||||
}
|
||||
|
||||
const (
|
||||
scimActionNone scimSyncAction = iota
|
||||
scimActionCreated
|
||||
@@ -149,7 +144,7 @@ func (s *ScimService) ScheduleSync() {
|
||||
|
||||
err := s.scheduler.RegisterJob(
|
||||
context.Background(), jobName,
|
||||
gocron.OneTimeJob(gocron.OneTimeJobStartDateTime(start)), s.SyncAll, false)
|
||||
gocron.OneTimeJob(gocron.OneTimeJobStartDateTime(start)), s.SyncAll, RegisterJobOpts{})
|
||||
|
||||
if err != nil {
|
||||
slog.Error("Failed to schedule SCIM sync", slog.Any("error", err))
|
||||
@@ -168,7 +163,8 @@ func (s *ScimService) SyncAll(ctx context.Context) error {
|
||||
errs = append(errs, ctx.Err())
|
||||
break
|
||||
}
|
||||
if err := s.SyncServiceProvider(ctx, provider.ID); err != nil {
|
||||
err = s.SyncServiceProvider(ctx, provider.ID)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to sync SCIM provider %s: %w", provider.ID, err))
|
||||
}
|
||||
}
|
||||
@@ -210,26 +206,20 @@ func (s *ScimService) SyncServiceProvider(ctx context.Context, serviceProviderID
|
||||
}
|
||||
|
||||
var errs []error
|
||||
var userStats scimSyncStats
|
||||
var groupStats scimSyncStats
|
||||
|
||||
// Sync users first, so that groups can reference them
|
||||
if stats, err := s.syncUsers(ctx, provider, users, &userResources); err != nil {
|
||||
errs = append(errs, err)
|
||||
userStats = stats
|
||||
} else {
|
||||
userStats = stats
|
||||
}
|
||||
|
||||
stats, err := s.syncGroups(ctx, provider, groups, groupResources.Resources, userResources.Resources)
|
||||
userStats, err := s.syncUsers(ctx, provider, users, &userResources)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
groupStats, err := s.syncGroups(ctx, provider, groups, groupResources.Resources, userResources.Resources)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
groupStats = stats
|
||||
} else {
|
||||
groupStats = stats
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
err = errors.Join(errs...)
|
||||
slog.WarnContext(ctx, "SCIM sync completed with errors",
|
||||
slog.String("provider_id", provider.ID),
|
||||
slog.Int("error_count", len(errs)),
|
||||
@@ -240,12 +230,14 @@ func (s *ScimService) SyncServiceProvider(ctx context.Context, serviceProviderID
|
||||
slog.Int("groups_updated", groupStats.Updated),
|
||||
slog.Int("groups_deleted", groupStats.Deleted),
|
||||
slog.Duration("duration", time.Since(start)),
|
||||
slog.Any("error", err),
|
||||
)
|
||||
return errors.Join(errs...)
|
||||
return err
|
||||
}
|
||||
|
||||
provider.LastSyncedAt = new(datatype.DateTime(time.Now()))
|
||||
if err := s.db.WithContext(ctx).Save(&provider).Error; err != nil {
|
||||
err = s.db.WithContext(ctx).Save(&provider).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -273,7 +265,7 @@ func (s *ScimService) syncUsers(
|
||||
|
||||
// Update or create users
|
||||
for _, u := range users {
|
||||
existing := getResourceByExternalID[dto.ScimUser](u.ID, resourceList.Resources)
|
||||
existing := getResourceByExternalID(u.ID, resourceList.Resources)
|
||||
|
||||
action, created, err := s.syncUser(ctx, provider, u, existing)
|
||||
if created != nil && existing == nil {
|
||||
@@ -434,7 +426,7 @@ func (s *ScimService) syncGroup(
|
||||
// Prepare group members
|
||||
members := make([]dto.ScimGroupMember, len(group.Users))
|
||||
for i, user := range group.Users {
|
||||
userResource := getResourceByExternalID[dto.ScimUser](user.ID, userResources)
|
||||
userResource := getResourceByExternalID(user.ID, userResources)
|
||||
if userResource == nil {
|
||||
// Groups depend on user IDs already being provisioned
|
||||
return scimActionNone, fmt.Errorf("cannot sync group %s: user %s is not provisioned in SCIM provider", group.ID, user.ID)
|
||||
|
||||
@@ -96,7 +96,10 @@ func (s *UserGroupService) Delete(ctx context.Context, id string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -126,7 +129,10 @@ func (s *UserGroupService) createInternal(ctx context.Context, input dto.UserGro
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
@@ -175,7 +181,10 @@ func (s *UserGroupService) updateInternal(ctx context.Context, id string, input
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
@@ -238,7 +247,10 @@ func (s *UserGroupService) updateUsersInternal(ctx context.Context, id string, u
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
@@ -315,6 +327,9 @@ func (s *UserGroupService) UpdateAllowedOidcClient(ctx context.Context, id strin
|
||||
return model.UserGroup{}, err
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
@@ -225,7 +225,10 @@ func (s *UserService) deleteUserInternal(ctx context.Context, tx *gorm.DB, userI
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -310,7 +313,10 @@ func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCrea
|
||||
}
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@@ -456,7 +462,10 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
|
||||
return user, err
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@@ -515,7 +524,10 @@ func (s *UserService) UpdateUserGroups(ctx context.Context, id string, userGroup
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@@ -576,7 +588,10 @@ func (s *UserService) disableUserInternal(ctx context.Context, tx *gorm.DB, user
|
||||
return err
|
||||
}
|
||||
|
||||
s.scimService.ScheduleSync()
|
||||
if s.scimService != nil {
|
||||
s.scimService.ScheduleSync()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
{{define "root"}}<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html dir="ltr" lang="en"><head><link rel="preload" as="image" href="{{.LogoURL}}"/><meta content="text/html; charset=UTF-8" http-equiv="Content-Type"/><meta name="x-apple-disable-message-reformatting"/></head><body style="background-color:#FBFBFB"><!--$--><!--html--><!--head--><!--body--><table border="0" width="100%" cellPadding="0" cellSpacing="0" role="presentation" align="center"><tbody><tr><td style="padding:50px;background-color:#FBFBFB;font-family:Arial, sans-serif"><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="max-width:37.5em;width:500px;margin:0 auto"><tbody><tr style="width:100%"><td><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody><tr><td><table align="left" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="margin-bottom:16px"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column" style="width:50px"><img alt="{{.AppName}}" height="32" src="{{.LogoURL}}" style="display:block;outline:none;border:none;text-decoration:none;width:32px;height:32px;vertical-align:middle" width="32"/></td><td data-id="__react-email-column"><p style="font-size:23px;line-height:24px;font-weight:bold;margin:0;padding:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">{{.AppName}}</p></td></tr></tbody></table></td></tr></tbody></table><div style="background-color:white;padding:24px;border-radius:10px;box-shadow:0 1px 4px 0px rgba(0, 0, 0, 0.1)"><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column"><h1 style="font-size:20px;font-weight:bold;margin:0">API Key Expiring Soon</h1></td><td align="right" data-id="__react-email-column"><p style="font-size:12px;line-height:24px;background-color:#ffd966;color:#7f6000;padding:1px 12px;border-radius:50px;display:inline-block;margin:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">Warning</p></td></tr></tbody></table><p style="font-size:14px;line-height:24px;margin-top:16px;margin-bottom:16px">Hello <!-- -->{{.Data.Name}}<!-- -->, <br/>This is a reminder that your API key <strong>{{.Data.APIKeyName}}</strong> <!-- -->will expire on <strong>{{.Data.ExpiresAt.Format "2006-01-02 15:04:05 MST"}}</strong>.</p><p style="font-size:14px;line-height:24px;margin-top:16px;margin-bottom:16px">Please generate a new API key if you need continued access.</p></div></td></tr></tbody></table></td></tr></tbody></table><!--/$--></body></html>{{end}}
|
||||
{{define "root"}}<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html dir="ltr" lang="en"><head><link rel="preload" as="image" href="{{.LogoURL}}"/><meta content="text/html; charset=UTF-8" http-equiv="Content-Type"/><meta name="x-apple-disable-message-reformatting"/></head><body style="background-color:#FBFBFB"><!--$--><!--html--><!--head--><!--body--><table border="0" width="100%" cellPadding="0" cellSpacing="0" role="presentation" align="center"><tbody><tr><td style="padding:50px;background-color:#FBFBFB;font-family:Arial, sans-serif"><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="max-width:37.5em;width:500px;margin:0 auto"><tbody><tr style="width:100%"><td><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody><tr><td><table align="left" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="margin-bottom:16px"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column" style="width:50px"><img alt="{{.AppName}}" height="32" src="{{.LogoURL}}" style="display:block;outline:none;border:none;text-decoration:none;width:32px;height:32px;vertical-align:middle" width="32"/></td><td data-id="__react-email-column"><p style="font-size:23px;line-height:24px;font-weight:bold;margin:0;padding:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">{{.AppName}}</p></td></tr></tbody></table></td></tr></tbody></table><div style="background-color:white;padding:24px;border-radius:10px;box-shadow:0 1px 4px 0px rgba(0, 0, 0, 0.1)"><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column"><h1 style="font-size:20px;font-weight:bold;margin:0">API Key Expiring Soon</h1></td><td align="right" data-id="__react-email-column"><p style="font-size:12px;line-height:24px;background-color:#ffd966;color:#7f6000;padding:1px 12px;border-radius:50px;display:inline-block;margin:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">Warning</p></td></tr></tbody></table><p style="font-size:14px;line-height:24px;margin-top:16px;margin-bottom:16px">Hello <!-- -->{{.Data.Name}}<!-- -->, <br/>This is a reminder that your API key <strong>{{.Data.ApiKeyName}}</strong> <!-- -->will expire on <strong>{{.Data.ExpiresAt.Format "2006-01-02 15:04:05 MST"}}</strong>.</p><p style="font-size:14px;line-height:24px;margin-top:16px;margin-bottom:16px">Please generate a new API key if you need continued access.</p></div></td></tr></tbody></table></td></tr></tbody></table><!--/$--></body></html>{{end}}
|
||||
@@ -6,6 +6,6 @@ API KEY EXPIRING SOON
|
||||
Warning
|
||||
|
||||
Hello {{.Data.Name}},
|
||||
This is a reminder that your API key {{.Data.APIKeyName}} will expire on {{.Data.ExpiresAt.Format "2006-01-02 15:04:05 MST"}}.
|
||||
This is a reminder that your API key {{.Data.ApiKeyName}} will expire on {{.Data.ExpiresAt.Format "2006-01-02 15:04:05 MST"}}.
|
||||
|
||||
Please generate a new API key if you need continued access.{{end}}
|
||||
@@ -0,0 +1 @@
|
||||
-- No-op
|
||||
@@ -0,0 +1,6 @@
|
||||
CREATE INDEX IF NOT EXISTS idx_webauthn_sessions_expires_at ON webauthn_sessions (expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_one_time_access_tokens_expires_at ON one_time_access_tokens (expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_oidc_authorization_codes_expires_at ON oidc_authorization_codes (expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_oidc_refresh_tokens_expires_at ON oidc_refresh_tokens (expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_reauthentication_tokens_expires_at ON reauthentication_tokens (expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_email_verification_tokens_expires_at ON email_verification_tokens (expires_at);
|
||||
@@ -0,0 +1 @@
|
||||
-- No-op
|
||||
@@ -0,0 +1,12 @@
|
||||
PRAGMA foreign_keys= OFF;
|
||||
BEGIN;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_webauthn_sessions_expires_at ON webauthn_sessions (expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_one_time_access_tokens_expires_at ON one_time_access_tokens (expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_oidc_authorization_codes_expires_at ON oidc_authorization_codes (expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_oidc_refresh_tokens_expires_at ON oidc_refresh_tokens (expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_reauthentication_tokens_expires_at ON reauthentication_tokens (expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_email_verification_tokens_expires_at ON email_verification_tokens (expires_at);
|
||||
|
||||
COMMIT;
|
||||
PRAGMA foreign_keys=ON;
|
||||
Reference in New Issue
Block a user