mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-03-27 09:46:36 +00:00
feat: disable/enable users (#437)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
This commit is contained in:
@@ -63,7 +63,7 @@ func initRouter(ctx context.Context, db *gorm.DB, appConfigService *service.AppC
|
||||
job.RegisterFileCleanupJobs(ctx, db)
|
||||
|
||||
// Initialize middleware for specific routes
|
||||
authMiddleware := middleware.NewAuthMiddleware(apiKeyService, jwtService)
|
||||
authMiddleware := middleware.NewAuthMiddleware(apiKeyService, userService, jwtService)
|
||||
fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware()
|
||||
|
||||
// Set up API routes
|
||||
|
||||
@@ -82,11 +82,6 @@ type FileTypeNotSupportedError struct{}
|
||||
func (e *FileTypeNotSupportedError) Error() string { return "file type not supported" }
|
||||
func (e *FileTypeNotSupportedError) HttpStatusCode() int { return 400 }
|
||||
|
||||
type InvalidCredentialsError struct{}
|
||||
|
||||
func (e *InvalidCredentialsError) Error() string { return "no user found with provided credentials" }
|
||||
func (e *InvalidCredentialsError) HttpStatusCode() int { return 400 }
|
||||
|
||||
type FileTooLargeError struct {
|
||||
MaxSize string
|
||||
}
|
||||
@@ -229,8 +224,7 @@ type InvalidUUIDError struct{}
|
||||
func (e *InvalidUUIDError) Error() string {
|
||||
return "Invalid UUID"
|
||||
}
|
||||
|
||||
type InvalidEmailError struct{}
|
||||
func (e *InvalidUUIDError) HttpStatusCode() int { return http.StatusBadRequest }
|
||||
|
||||
type OneTimeAccessDisabledError struct{}
|
||||
|
||||
@@ -244,31 +238,34 @@ type InvalidAPIKeyError struct{}
|
||||
func (e *InvalidAPIKeyError) Error() string {
|
||||
return "Invalid Api Key"
|
||||
}
|
||||
func (e *InvalidAPIKeyError) HttpStatusCode() int { return http.StatusUnauthorized }
|
||||
|
||||
type NoAPIKeyProvidedError struct{}
|
||||
|
||||
func (e *NoAPIKeyProvidedError) Error() string {
|
||||
return "No API Key Provided"
|
||||
}
|
||||
func (e *NoAPIKeyProvidedError) HttpStatusCode() int { return http.StatusUnauthorized }
|
||||
|
||||
type APIKeyNotFoundError struct{}
|
||||
|
||||
func (e *APIKeyNotFoundError) Error() string {
|
||||
return "API Key Not Found"
|
||||
}
|
||||
func (e *APIKeyNotFoundError) HttpStatusCode() int { return http.StatusUnauthorized }
|
||||
|
||||
type APIKeyExpirationDateError struct{}
|
||||
|
||||
func (e *APIKeyExpirationDateError) Error() string {
|
||||
return "API Key expiration time must be in the future"
|
||||
}
|
||||
func (e *APIKeyExpirationDateError) HttpStatusCode() int { return http.StatusBadRequest }
|
||||
|
||||
type OidcInvalidRefreshTokenError struct{}
|
||||
|
||||
func (e *OidcInvalidRefreshTokenError) Error() string {
|
||||
return "refresh token is invalid or expired"
|
||||
}
|
||||
|
||||
func (e *OidcInvalidRefreshTokenError) HttpStatusCode() int {
|
||||
return http.StatusBadRequest
|
||||
}
|
||||
@@ -278,7 +275,6 @@ type OidcMissingRefreshTokenError struct{}
|
||||
func (e *OidcMissingRefreshTokenError) Error() string {
|
||||
return "refresh token is required"
|
||||
}
|
||||
|
||||
func (e *OidcMissingRefreshTokenError) HttpStatusCode() int {
|
||||
return http.StatusBadRequest
|
||||
}
|
||||
@@ -288,7 +284,15 @@ type OidcMissingAuthorizationCodeError struct{}
|
||||
func (e *OidcMissingAuthorizationCodeError) Error() string {
|
||||
return "authorization code is required"
|
||||
}
|
||||
|
||||
func (e *OidcMissingAuthorizationCodeError) HttpStatusCode() int {
|
||||
return http.StatusBadRequest
|
||||
}
|
||||
|
||||
type UserDisabledError struct{}
|
||||
|
||||
func (e *UserDisabledError) Error() string {
|
||||
return "User account is disabled"
|
||||
}
|
||||
func (e *UserDisabledError) HttpStatusCode() int {
|
||||
return http.StatusForbidden
|
||||
}
|
||||
|
||||
@@ -42,6 +42,7 @@ type AppConfigUpdateDto struct {
|
||||
LdapAttributeGroupUniqueIdentifier string `json:"ldapAttributeGroupUniqueIdentifier"`
|
||||
LdapAttributeGroupName string `json:"ldapAttributeGroupName"`
|
||||
LdapAttributeAdminGroup string `json:"ldapAttributeAdminGroup"`
|
||||
LdapSoftDeleteUsers string `json:"ldapSoftDeleteUsers"`
|
||||
EmailOneTimeAccessEnabled string `json:"emailOneTimeAccessEnabled" binding:"required"`
|
||||
EmailLoginNotificationEnabled string `json:"emailLoginNotificationEnabled" binding:"required"`
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ type UserDto struct {
|
||||
CustomClaims []CustomClaimDto `json:"customClaims"`
|
||||
UserGroups []UserGroupDto `json:"userGroups"`
|
||||
LdapID *string `json:"ldapId"`
|
||||
Disabled bool `json:"disabled"`
|
||||
}
|
||||
|
||||
type UserCreateDto struct {
|
||||
@@ -22,6 +23,7 @@ type UserCreateDto struct {
|
||||
LastName string `json:"lastName" binding:"required,min=1,max=50"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
Locale *string `json:"locale"`
|
||||
Disabled bool `json:"disabled"`
|
||||
LdapID string `json:"-"`
|
||||
}
|
||||
|
||||
|
||||
@@ -41,7 +41,10 @@ func (m *ApiKeyAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (userI
|
||||
return "", false, &common.NotSignedInError{}
|
||||
}
|
||||
|
||||
// Check if the user is an admin
|
||||
if user.Disabled {
|
||||
return "", false, &common.UserDisabledError{}
|
||||
}
|
||||
|
||||
if adminRequired && !user.IsAdmin {
|
||||
return "", false, &common.MissingPermissionError{}
|
||||
}
|
||||
|
||||
@@ -19,11 +19,12 @@ type AuthOptions struct {
|
||||
|
||||
func NewAuthMiddleware(
|
||||
apiKeyService *service.ApiKeyService,
|
||||
userService *service.UserService,
|
||||
jwtService *service.JwtService,
|
||||
) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
apiKeyMiddleware: NewApiKeyAuthMiddleware(apiKeyService, jwtService),
|
||||
jwtMiddleware: NewJwtAuthMiddleware(jwtService),
|
||||
jwtMiddleware: NewJwtAuthMiddleware(jwtService, userService),
|
||||
options: AuthOptions{
|
||||
AdminRequired: true,
|
||||
SuccessOptional: false,
|
||||
@@ -57,12 +58,13 @@ func (m *AuthMiddleware) WithSuccessOptional() *AuthMiddleware {
|
||||
|
||||
func (m *AuthMiddleware) Add() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// First try JWT auth
|
||||
userID, isAdmin, err := m.jwtMiddleware.Verify(c, m.options.AdminRequired)
|
||||
if err == nil {
|
||||
// JWT auth succeeded, continue with the request
|
||||
c.Set("userID", userID)
|
||||
c.Set("userIsAdmin", isAdmin)
|
||||
if c.IsAborted() {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -70,9 +72,11 @@ func (m *AuthMiddleware) Add() gin.HandlerFunc {
|
||||
// JWT auth failed, try API key auth
|
||||
userID, isAdmin, err = m.apiKeyMiddleware.Verify(c, m.options.AdminRequired)
|
||||
if err == nil {
|
||||
// API key auth succeeded, continue with the request
|
||||
c.Set("userID", userID)
|
||||
c.Set("userIsAdmin", isAdmin)
|
||||
if c.IsAborted() {
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
)
|
||||
|
||||
type JwtAuthMiddleware struct {
|
||||
jwtService *service.JwtService
|
||||
userService *service.UserService
|
||||
jwtService *service.JwtService
|
||||
}
|
||||
|
||||
func NewJwtAuthMiddleware(jwtService *service.JwtService) *JwtAuthMiddleware {
|
||||
return &JwtAuthMiddleware{jwtService: jwtService}
|
||||
func NewJwtAuthMiddleware(jwtService *service.JwtService, userService *service.UserService) *JwtAuthMiddleware {
|
||||
return &JwtAuthMiddleware{jwtService: jwtService, userService: userService}
|
||||
}
|
||||
|
||||
func (m *JwtAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
|
||||
@@ -55,12 +56,16 @@ func (m *JwtAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (subject
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the user is an admin
|
||||
isAdmin, err = service.GetIsAdmin(token)
|
||||
user, err := m.userService.GetUser(c, subject)
|
||||
if err != nil {
|
||||
return "", false, &common.TokenInvalidError{}
|
||||
return "", false, &common.NotSignedInError{}
|
||||
}
|
||||
if adminRequired && !isAdmin {
|
||||
|
||||
if user.Disabled {
|
||||
return "", false, &common.UserDisabledError{}
|
||||
}
|
||||
|
||||
if adminRequired && !user.IsAdmin {
|
||||
return "", false, &common.MissingPermissionError{}
|
||||
}
|
||||
|
||||
|
||||
@@ -69,6 +69,7 @@ type AppConfig struct {
|
||||
LdapAttributeGroupUniqueIdentifier AppConfigVariable `key:"ldapAttributeGroupUniqueIdentifier"`
|
||||
LdapAttributeGroupName AppConfigVariable `key:"ldapAttributeGroupName"`
|
||||
LdapAttributeAdminGroup AppConfigVariable `key:"ldapAttributeAdminGroup"`
|
||||
LdapSoftDeleteUsers AppConfigVariable `key:"ldapSoftDeleteUsers"`
|
||||
}
|
||||
|
||||
func (c *AppConfig) ToAppConfigVariableSlice(showAll bool) []AppConfigVariable {
|
||||
|
||||
@@ -19,6 +19,7 @@ type User struct {
|
||||
IsAdmin bool `sortable:"true"`
|
||||
Locale *string
|
||||
LdapID *string
|
||||
Disabled bool `sortable:"true"`
|
||||
|
||||
CustomClaims []CustomClaim
|
||||
UserGroups []UserGroup `gorm:"many2many:user_groups_users;"`
|
||||
|
||||
@@ -93,6 +93,7 @@ func (s *AppConfigService) getDefaultDbConfig() *model.AppConfig {
|
||||
LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{},
|
||||
LdapAttributeGroupName: model.AppConfigVariable{},
|
||||
LdapAttributeAdminGroup: model.AppConfigVariable{},
|
||||
LdapSoftDeleteUsers: model.AppConfigVariable{Value: "true"},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -279,6 +279,22 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
Where("ldap_id = ?", ldapId).
|
||||
First(&databaseUser).
|
||||
Error
|
||||
|
||||
// If a user is found (even if disabled), enable them since they're now back in LDAP
|
||||
if databaseUser.ID != "" && databaseUser.Disabled {
|
||||
// Use the transaction instead of the direct context
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Model(&model.User{}).
|
||||
Where("id = ?", databaseUser.ID).
|
||||
Update("disabled", false).
|
||||
Error
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Failed to enable user %s: %v", databaseUser.Username, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// This could error with ErrRecordNotFound and we want to ignore that here
|
||||
return fmt.Errorf("failed to query for LDAP user ID '%s': %w", ldapId, err)
|
||||
@@ -336,24 +352,32 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
||||
err = tx.
|
||||
WithContext(ctx).
|
||||
Find(&ldapUsersInDb, "ldap_id IS NOT NULL").
|
||||
Select("ldap_id").
|
||||
Select("id, username, ldap_id, disabled").
|
||||
Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch users from database: %w", err)
|
||||
}
|
||||
|
||||
// Delete users that no longer exist in LDAP
|
||||
// Mark users as disabled or delete users that no longer exist in LDAP
|
||||
for _, user := range ldapUsersInDb {
|
||||
// Skip if the user ID exists in the fetched LDAP results
|
||||
if _, exists := ldapUserIDs[*user.LdapID]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
err = s.userService.deleteUserInternal(ctx, user.ID, true, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user '%s': %w", user.Username, err)
|
||||
if dbConfig.LdapSoftDeleteUsers.IsTrue() {
|
||||
err = s.userService.DisableUser(ctx, user.ID, tx)
|
||||
if err != nil {
|
||||
log.Printf("Failed to disable user %s: %v", user.Username, err)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
err = s.userService.DeleteUser(ctx, user.ID, true)
|
||||
if err != nil {
|
||||
log.Printf("Failed to delete user %s: %v", user.Username, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Deleted user '%s'", user.Username)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -38,14 +38,19 @@ func NewUserService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditL
|
||||
|
||||
func (s *UserService) ListUsers(ctx context.Context, searchTerm string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.User, utils.PaginationResponse, error) {
|
||||
var users []model.User
|
||||
query := s.db.WithContext(ctx).Model(&model.User{})
|
||||
query := s.db.WithContext(ctx).
|
||||
Model(&model.User{}).
|
||||
Preload("UserGroups").
|
||||
Preload("CustomClaims")
|
||||
|
||||
if searchTerm != "" {
|
||||
searchPattern := "%" + searchTerm + "%"
|
||||
query = query.Where("email LIKE ? OR first_name LIKE ? OR username LIKE ?", searchPattern, searchPattern, searchPattern)
|
||||
query = query.Where("email LIKE ? OR first_name LIKE ? OR last_name LIKE ? OR username LIKE ?",
|
||||
searchPattern, searchPattern, searchPattern, searchPattern)
|
||||
}
|
||||
|
||||
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &users)
|
||||
|
||||
return users, pagination, err
|
||||
}
|
||||
|
||||
@@ -170,9 +175,28 @@ func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error
|
||||
}
|
||||
|
||||
func (s *UserService) DeleteUser(ctx context.Context, userID string, allowLdapDelete bool) error {
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
return s.deleteUserInternal(ctx, userID, allowLdapDelete, tx)
|
||||
})
|
||||
tx := s.db.Begin()
|
||||
|
||||
var user model.User
|
||||
if err := tx.WithContext(ctx).First(&user, "id = ?", userID).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// Only soft-delete if user is LDAP and soft-delete is enabled and not allowing hard delete
|
||||
if user.LdapID != nil && s.appConfigService.GetDbConfig().LdapSoftDeleteUsers.IsTrue() && !allowLdapDelete {
|
||||
if !user.Disabled {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("LDAP user must be disabled before deletion")
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, hard delete (local users or LDAP users when allowed)
|
||||
if err := s.deleteUserInternal(ctx, userID, allowLdapDelete, tx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
func (s *UserService) deleteUserInternal(ctx context.Context, userID string, allowLdapDelete bool, tx *gorm.DB) error {
|
||||
@@ -187,8 +211,8 @@ func (s *UserService) deleteUserInternal(ctx context.Context, userID string, all
|
||||
return fmt.Errorf("failed to load user to delete: %w", err)
|
||||
}
|
||||
|
||||
// Disallow deleting the user if it is an LDAP user and LDAP is enabled
|
||||
if !allowLdapDelete && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
|
||||
// Disallow deleting the user if it is an LDAP user, LDAP is enabled, and the user is not disabled
|
||||
if !allowLdapDelete && !user.Disabled && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
|
||||
return &common.LdapUserUpdateError{}
|
||||
}
|
||||
|
||||
@@ -299,6 +323,7 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
|
||||
user.Locale = updatedUser.Locale
|
||||
if !updateOwnUser {
|
||||
user.IsAdmin = updatedUser.IsAdmin
|
||||
user.Disabled = updatedUser.Disabled
|
||||
}
|
||||
|
||||
err = tx.
|
||||
@@ -606,3 +631,11 @@ func (s *UserService) ResetProfilePicture(userID string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserService) DisableUser(ctx context.Context, userID string, tx *gorm.DB) error {
|
||||
return tx.WithContext(ctx).
|
||||
Model(&model.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("disabled", true).
|
||||
Error
|
||||
}
|
||||
|
||||
@@ -244,6 +244,10 @@ func (s *WebAuthnService) VerifyLogin(ctx context.Context, sessionID string, cre
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
if user.Disabled {
|
||||
return model.User{}, "", &common.UserDisabledError{}
|
||||
}
|
||||
|
||||
token, err := s.jwtService.GenerateAccessToken(*user)
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
DROP INDEX idx_users_disabled;
|
||||
|
||||
ALTER TABLE users
|
||||
DROP COLUMN disabled;
|
||||
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE users
|
||||
ADD COLUMN IF NOT EXISTS disabled BOOLEAN DEFAULT FALSE NOT NULL;
|
||||
@@ -0,0 +1,4 @@
|
||||
DROP INDEX idx_users_disabled;
|
||||
|
||||
ALTER TABLE users
|
||||
DROP COLUMN disabled;
|
||||
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE users
|
||||
ADD COLUMN disabled NUMERIC DEFAULT FALSE NOT NULL;
|
||||
Reference in New Issue
Block a user