refactor: simplify app_config service and fix race conditions (#423)

This commit is contained in:
Alessandro (Ale) Segala
2025-04-10 04:41:22 -07:00
committed by GitHub
parent 4ba68938dd
commit f83bab9e17
28 changed files with 1263 additions and 560 deletions

View File

@@ -3,30 +3,35 @@ package service
import (
"context"
"errors"
"fmt"
"log"
"mime/multipart"
"os"
"reflect"
"strings"
"sync/atomic"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
)
type AppConfigService struct {
DbConfig *model.AppConfig
dbConfig atomic.Pointer[model.AppConfig]
db *gorm.DB
}
func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService {
service := &AppConfigService{
DbConfig: &defaultDbConfig,
db: db,
db: db,
}
err := service.InitDbConfig(ctx)
err := service.LoadDbConfig(ctx)
if err != nil {
log.Fatalf("Failed to initialize app config service: %v", err)
}
@@ -34,170 +39,109 @@ func NewAppConfigService(ctx context.Context, db *gorm.DB) *AppConfigService {
return service
}
var defaultDbConfig = model.AppConfig{
// General
AppName: model.AppConfigVariable{
Key: "appName",
Type: "string",
IsPublic: true,
DefaultValue: "Pocket ID",
},
SessionDuration: model.AppConfigVariable{
Key: "sessionDuration",
Type: "number",
DefaultValue: "60",
},
EmailsVerified: model.AppConfigVariable{
Key: "emailsVerified",
Type: "bool",
DefaultValue: "false",
},
AllowOwnAccountEdit: model.AppConfigVariable{
Key: "allowOwnAccountEdit",
Type: "bool",
IsPublic: true,
DefaultValue: "true",
},
// Internal
BackgroundImageType: model.AppConfigVariable{
Key: "backgroundImageType",
Type: "string",
IsInternal: true,
DefaultValue: "jpg",
},
LogoLightImageType: model.AppConfigVariable{
Key: "logoLightImageType",
Type: "string",
IsInternal: true,
DefaultValue: "svg",
},
LogoDarkImageType: model.AppConfigVariable{
Key: "logoDarkImageType",
Type: "string",
IsInternal: true,
DefaultValue: "svg",
},
// Email
SmtpHost: model.AppConfigVariable{
Key: "smtpHost",
Type: "string",
},
SmtpPort: model.AppConfigVariable{
Key: "smtpPort",
Type: "number",
},
SmtpFrom: model.AppConfigVariable{
Key: "smtpFrom",
Type: "string",
},
SmtpUser: model.AppConfigVariable{
Key: "smtpUser",
Type: "string",
},
SmtpPassword: model.AppConfigVariable{
Key: "smtpPassword",
Type: "string",
},
SmtpTls: model.AppConfigVariable{
Key: "smtpTls",
Type: "string",
DefaultValue: "none",
},
SmtpSkipCertVerify: model.AppConfigVariable{
Key: "smtpSkipCertVerify",
Type: "bool",
DefaultValue: "false",
},
EmailLoginNotificationEnabled: model.AppConfigVariable{
Key: "emailLoginNotificationEnabled",
Type: "bool",
DefaultValue: "false",
},
EmailOneTimeAccessEnabled: model.AppConfigVariable{
Key: "emailOneTimeAccessEnabled",
Type: "bool",
IsPublic: true,
DefaultValue: "false",
},
// LDAP
LdapEnabled: model.AppConfigVariable{
Key: "ldapEnabled",
Type: "bool",
IsPublic: true,
DefaultValue: "false",
},
LdapUrl: model.AppConfigVariable{
Key: "ldapUrl",
Type: "string",
},
LdapBindDn: model.AppConfigVariable{
Key: "ldapBindDn",
Type: "string",
},
LdapBindPassword: model.AppConfigVariable{
Key: "ldapBindPassword",
Type: "string",
},
LdapBase: model.AppConfigVariable{
Key: "ldapBase",
Type: "string",
},
LdapUserSearchFilter: model.AppConfigVariable{
Key: "ldapUserSearchFilter",
Type: "string",
DefaultValue: "(objectClass=person)",
},
LdapUserGroupSearchFilter: model.AppConfigVariable{
Key: "ldapUserGroupSearchFilter",
Type: "string",
DefaultValue: "(objectClass=groupOfNames)",
},
LdapSkipCertVerify: model.AppConfigVariable{
Key: "ldapSkipCertVerify",
Type: "bool",
DefaultValue: "false",
},
LdapAttributeUserUniqueIdentifier: model.AppConfigVariable{
Key: "ldapAttributeUserUniqueIdentifier",
Type: "string",
},
LdapAttributeUserUsername: model.AppConfigVariable{
Key: "ldapAttributeUserUsername",
Type: "string",
},
LdapAttributeUserEmail: model.AppConfigVariable{
Key: "ldapAttributeUserEmail",
Type: "string",
},
LdapAttributeUserFirstName: model.AppConfigVariable{
Key: "ldapAttributeUserFirstName",
Type: "string",
},
LdapAttributeUserLastName: model.AppConfigVariable{
Key: "ldapAttributeUserLastName",
Type: "string",
},
LdapAttributeUserProfilePicture: model.AppConfigVariable{
Key: "ldapAttributeUserProfilePicture",
Type: "string",
},
LdapAttributeGroupMember: model.AppConfigVariable{
Key: "ldapAttributeGroupMember",
Type: "string",
DefaultValue: "member",
},
LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{
Key: "ldapAttributeGroupUniqueIdentifier",
Type: "string",
},
LdapAttributeGroupName: model.AppConfigVariable{
Key: "ldapAttributeGroupName",
Type: "string",
},
LdapAttributeAdminGroup: model.AppConfigVariable{
Key: "ldapAttributeAdminGroup",
Type: "string",
},
// GetDbConfig returns the application configuration.
// Important: Treat the object as read-only: do not modify its properties directly!
func (s *AppConfigService) GetDbConfig() *model.AppConfig {
v := s.dbConfig.Load()
if v == nil {
// This indicates a development-time error
panic("called GetDbConfig before DbConfig is loaded")
}
return v
}
func (s *AppConfigService) getDefaultDbConfig() *model.AppConfig {
// Values are the default ones
return &model.AppConfig{
// General
AppName: model.AppConfigVariable{Value: "Pocket ID"},
SessionDuration: model.AppConfigVariable{Value: "60"},
EmailsVerified: model.AppConfigVariable{Value: "false"},
AllowOwnAccountEdit: model.AppConfigVariable{Value: "true"},
// Internal
BackgroundImageType: model.AppConfigVariable{Value: "jpg"},
LogoLightImageType: model.AppConfigVariable{Value: "svg"},
LogoDarkImageType: model.AppConfigVariable{Value: "svg"},
// Email
SmtpHost: model.AppConfigVariable{},
SmtpPort: model.AppConfigVariable{},
SmtpFrom: model.AppConfigVariable{},
SmtpUser: model.AppConfigVariable{},
SmtpPassword: model.AppConfigVariable{},
SmtpTls: model.AppConfigVariable{Value: "none"},
SmtpSkipCertVerify: model.AppConfigVariable{Value: "false"},
EmailLoginNotificationEnabled: model.AppConfigVariable{Value: "false"},
EmailOneTimeAccessEnabled: model.AppConfigVariable{Value: "false"},
// LDAP
LdapEnabled: model.AppConfigVariable{Value: "false"},
LdapUrl: model.AppConfigVariable{},
LdapBindDn: model.AppConfigVariable{},
LdapBindPassword: model.AppConfigVariable{},
LdapBase: model.AppConfigVariable{},
LdapUserSearchFilter: model.AppConfigVariable{Value: "(objectClass=person)"},
LdapUserGroupSearchFilter: model.AppConfigVariable{Value: "(objectClass=groupOfNames)"},
LdapSkipCertVerify: model.AppConfigVariable{Value: "false"},
LdapAttributeUserUniqueIdentifier: model.AppConfigVariable{},
LdapAttributeUserUsername: model.AppConfigVariable{},
LdapAttributeUserEmail: model.AppConfigVariable{},
LdapAttributeUserFirstName: model.AppConfigVariable{},
LdapAttributeUserLastName: model.AppConfigVariable{},
LdapAttributeUserProfilePicture: model.AppConfigVariable{},
LdapAttributeGroupMember: model.AppConfigVariable{Value: "member"},
LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{},
LdapAttributeGroupName: model.AppConfigVariable{},
LdapAttributeAdminGroup: model.AppConfigVariable{},
}
}
func (s *AppConfigService) updateAppConfigStartTransaction(ctx context.Context) (tx *gorm.DB, err error) {
// We start a transaction before doing any work, to ensure that we are the only ones updating the data in the database
// This works across multiple processes too
tx = s.db.Begin()
err = tx.Error
if err != nil {
return nil, fmt.Errorf("failed to begin database transaction: %w", err)
}
// With SQLite there's nothing else we need to do, because a transaction blocks the entire database
// However, with Postgres we need to manually lock the table to prevent others from doing the same
switch s.db.Name() {
case "postgres":
// We do not use "NOWAIT" so this blocks until the database is available, or the context is canceled
// Here we use a context with a 10s timeout in case the database is blocked for longer
lockCtx, lockCancel := context.WithTimeout(ctx, 10*time.Second)
defer lockCancel()
err = tx.
WithContext(lockCtx).
Exec("LOCK TABLE app_config_variables IN ACCESS EXCLUSIVE MODE").
Error
if err != nil {
tx.Rollback()
return nil, fmt.Errorf("failed to acquire lock on app_config_variables table: %w", err)
}
default:
// Nothing to do here
}
return tx, nil
}
func (s *AppConfigService) updateAppConfigUpdateDatabase(ctx context.Context, tx *gorm.DB, dbUpdate *[]model.AppConfigVariable) error {
err := tx.
WithContext(ctx).
Clauses(clause.OnConflict{
// Perform an "upsert" if the key already exists, replacing the value
Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.AssignmentColumns([]string{"value"}),
}).
Create(&dbUpdate).
Error
if err != nil {
return fmt.Errorf("failed to update config in database: %w", err)
}
return nil
}
func (s *AppConfigService) UpdateAppConfig(ctx context.Context, input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
@@ -205,106 +149,168 @@ func (s *AppConfigService) UpdateAppConfig(ctx context.Context, input dto.AppCon
return nil, &common.UiConfigDisabledError{}
}
tx := s.db.Begin()
// If EmailLoginNotificationEnabled is set to false (explicitly), disable the EmailOneTimeAccessEnabled
if input.EmailLoginNotificationEnabled == "false" {
input.EmailOneTimeAccessEnabled = "false"
}
// Start the transaction
tx, err := s.updateAppConfigStartTransaction(ctx)
if err != nil {
return nil, err
}
defer func() {
tx.Rollback()
}()
var err error
// From here onwards, we know we are the only process/goroutine with exclusive access to the config
// Re-load the config from the database to be sure we have the correct data
cfg, err := s.loadDbConfigInternal(ctx, tx)
if err != nil {
return nil, fmt.Errorf("failed to reload config from database: %w", err)
}
defaultCfg := s.getDefaultDbConfig()
// Iterate through all the fields to update
// We update the in-memory data (in the cfg struct) and collect values to update in the database
rt := reflect.ValueOf(input).Type()
rv := reflect.ValueOf(input)
savedConfigVariables := make([]model.AppConfigVariable, 0, rt.NumField())
dbUpdate := make([]model.AppConfigVariable, 0, rt.NumField())
for i := range rt.NumField() {
field := rt.Field(i)
key := field.Tag.Get("json")
value := rv.FieldByName(field.Name).String()
// If the emailEnabled is set to false, disable the emailOneTimeAccessEnabled
if key == s.DbConfig.EmailOneTimeAccessEnabled.Key {
if rv.FieldByName("EmailEnabled").String() == "false" {
value = "false"
}
// Get the value of the json tag, taking only what's before the comma
key, _, _ := strings.Cut(field.Tag.Get("json"), ",")
// Update the in-memory config value
// If the new value is an empty string, then we set the in-memory value to the default one
// Skip values that are internal only and can't be updated
if value == "" {
// Ignore errors here as we know the key exists
defaultValue, _ := defaultCfg.FieldByKey(key)
err = cfg.UpdateField(key, defaultValue, true)
} else {
err = cfg.UpdateField(key, value, true)
}
var appConfigVariable model.AppConfigVariable
err = tx.
WithContext(ctx).
First(&appConfigVariable, "key = ? AND is_internal = false", key).
Error
if err != nil {
return nil, err
// If we tried to update an internal field, ignore the error (and do not update in the DB)
if errors.Is(err, model.AppConfigInternalForbiddenError{}) {
continue
} else if err != nil {
return nil, fmt.Errorf("failed to update in-memory config for key '%s': %w", key, err)
}
appConfigVariable.Value = value
err = tx.
WithContext(ctx).
Save(&appConfigVariable).
Error
if err != nil {
return nil, err
}
savedConfigVariables = append(savedConfigVariables, appConfigVariable)
// We always save "value" which can be an empty string
dbUpdate = append(dbUpdate, model.AppConfigVariable{
Key: key,
Value: value,
})
}
// Update the values in the database
err = s.updateAppConfigUpdateDatabase(ctx, tx, &dbUpdate)
if err != nil {
return nil, err
}
// Commit the changes to the DB, then finally save the updated config in the object
err = tx.Commit().Error
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to commit transaction: %w", err)
}
err = s.LoadDbConfigFromDb()
if err != nil {
return nil, err
}
s.dbConfig.Store(cfg)
return savedConfigVariables, nil
// Return the updated config
res := cfg.ToAppConfigVariableSlice(true)
return res, nil
}
func (s *AppConfigService) updateImageType(ctx context.Context, imageName string, fileType string) error {
key := imageName + "ImageType"
err := s.db.
WithContext(ctx).
Model(&model.AppConfigVariable{}).
Where("key = ?", key).
Update("value", fileType).
Error
// UpdateAppConfigValues
func (s *AppConfigService) UpdateAppConfigValues(ctx context.Context, keysAndValues ...string) error {
if common.EnvConfig.UiConfigDisabled {
return &common.UiConfigDisabledError{}
}
// Count of keysAndValues must be even
if len(keysAndValues)%2 != 0 {
return errors.New("invalid number of arguments received")
}
// Start the transaction
tx, err := s.updateAppConfigStartTransaction(ctx)
if err != nil {
return err
}
defer func() {
tx.Rollback()
}()
// From here onwards, we know we are the only process/goroutine with exclusive access to the config
// Re-load the config from the database to be sure we have the correct data
cfg, err := s.loadDbConfigInternal(ctx, tx)
if err != nil {
return fmt.Errorf("failed to reload config from database: %w", err)
}
defaultCfg := s.getDefaultDbConfig()
// Iterate through all the fields to update
// We update the in-memory data (in the cfg struct) and collect values to update in the database
// (Note the += 2, as we are iterating through key-value pairs)
dbUpdate := make([]model.AppConfigVariable, 0, len(keysAndValues)/2)
for i := 0; i < len(keysAndValues); i += 2 {
key := keysAndValues[i]
value := keysAndValues[i+1]
// Ensure that the field is valid
// We do this by grabbing the default value
var defaultValue string
defaultValue, err = defaultCfg.FieldByKey(key)
if err != nil {
return fmt.Errorf("invalid configuration key '%s': %w", key, err)
}
// Update the in-memory config value
// If the new value is an empty string, then we set the in-memory value to the default one
// Skip values that are internal only and can't be updated
if value == "" {
err = cfg.UpdateField(key, defaultValue, false)
} else {
err = cfg.UpdateField(key, value, false)
}
if err != nil {
return fmt.Errorf("failed to update in-memory config for key '%s': %w", key, err)
}
// We always save "value" which can be an empty string
dbUpdate = append(dbUpdate, model.AppConfigVariable{
Key: key,
Value: value,
})
}
// Update the values in the database
err = s.updateAppConfigUpdateDatabase(ctx, tx, &dbUpdate)
if err != nil {
return err
}
return s.LoadDbConfigFromDb()
// Commit the changes to the DB, then finally save the updated config in the object
err = tx.Commit().Error
if err != nil {
return fmt.Errorf("failed to commit transaction: %w", err)
}
s.dbConfig.Store(cfg)
return nil
}
func (s *AppConfigService) ListAppConfig(ctx context.Context, showAll bool) (configuration []model.AppConfigVariable, err error) {
if showAll {
err = s.db.
WithContext(ctx).
Find(&configuration).
Error
} else {
err = s.db.
WithContext(ctx).
Find(&configuration, "is_public = true").
Error
}
if err != nil {
return nil, err
}
for i := range configuration {
if common.EnvConfig.UiConfigDisabled {
// Set the value to the environment variable if the UI config is disabled
configuration[i].Value = s.getConfigVariableFromEnvironmentVariable(configuration[i].Key, configuration[i].DefaultValue)
} else if configuration[i].Value == "" && configuration[i].DefaultValue != "" {
// Set the value to the default value if it is empty
configuration[i].Value = configuration[i].DefaultValue
}
}
return configuration, nil
func (s *AppConfigService) ListAppConfig(showAll bool) []model.AppConfigVariable {
return s.GetDbConfig().ToAppConfigVariableSlice(showAll)
}
func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multipart.FileHeader, imageName string, oldImageType string) (err error) {
@@ -314,161 +320,108 @@ func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multip
return &common.FileTypeNotSupportedError{}
}
// Delete the old image if it has a different file type
if fileType != oldImageType {
oldImagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + oldImageType
err = os.Remove(oldImagePath)
if err != nil {
return err
}
}
// Save the updated image
imagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + fileType
err = utils.SaveFile(uploadedFile, imagePath)
if err != nil {
return err
}
// Update the file type in the database
err = s.updateImageType(ctx, imageName, fileType)
if err != nil {
return err
}
return nil
}
// InitDbConfig creates the default configuration values in the database if they do not exist,
// updates existing configurations if they differ from the default, and deletes any configurations
// that are not in the default configuration.
func (s *AppConfigService) InitDbConfig(ctx context.Context) (err error) {
tx := s.db.Begin()
defer func() {
tx.Rollback()
}()
// Reflect to get the underlying value of DbConfig and its default configuration
defaultConfigReflectValue := reflect.ValueOf(defaultDbConfig)
defaultKeys := make(map[string]struct{})
// Iterate over the fields of DbConfig
for i := range defaultConfigReflectValue.NumField() {
defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.AppConfigVariable)
defaultKeys[defaultConfigVar.Key] = struct{}{}
var storedConfigVar model.AppConfigVariable
err = tx.
WithContext(ctx).
First(&storedConfigVar, "key = ?", defaultConfigVar.Key).
Error
if errors.Is(err, gorm.ErrRecordNotFound) {
// If the configuration does not exist, create it
err = tx.
WithContext(ctx).
Create(&defaultConfigVar).
Error
if err != nil {
return err
}
continue
} else if err != nil {
return err
}
// Update existing configuration if it differs from the default
if storedConfigVar.Type != defaultConfigVar.Type ||
storedConfigVar.IsPublic != defaultConfigVar.IsPublic ||
storedConfigVar.IsInternal != defaultConfigVar.IsInternal ||
storedConfigVar.DefaultValue != defaultConfigVar.DefaultValue {
// Set values
storedConfigVar.Type = defaultConfigVar.Type
storedConfigVar.IsPublic = defaultConfigVar.IsPublic
storedConfigVar.IsInternal = defaultConfigVar.IsInternal
storedConfigVar.DefaultValue = defaultConfigVar.DefaultValue
err = tx.
WithContext(ctx).
Save(&storedConfigVar).
Error
if err != nil {
return err
}
}
}
// Delete any configurations not in the default keys
var allConfigVars []model.AppConfigVariable
err = tx.
WithContext(ctx).
Find(&allConfigVars).
Error
if err != nil {
return err
}
for _, config := range allConfigVars {
if _, exists := defaultKeys[config.Key]; exists {
continue
}
err = tx.
WithContext(ctx).
Delete(&config).
Error
// Delete the old image if it has a different file type, then update the type in the database
if fileType != oldImageType {
oldImagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + oldImageType
err = os.Remove(oldImagePath)
if err != nil {
return err
}
}
// Commit the changes
err = tx.Commit().Error
if err != nil {
return err
}
// Update the file type in the database
err = s.UpdateAppConfigValues(ctx, imageName+"ImageType", fileType)
if err != nil {
return err
}
// Reload the configuration
err = s.LoadDbConfigFromDb()
if err != nil {
return err
}
return nil
}
// LoadDbConfigFromDb loads the configuration values from the database into the DbConfig struct.
func (s *AppConfigService) LoadDbConfigFromDb() error {
return s.db.Transaction(func(tx *gorm.DB) error {
dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem()
// LoadDbConfig loads the configuration values from the database into the DbConfig struct.
func (s *AppConfigService) LoadDbConfig(ctx context.Context) (err error) {
var dest *model.AppConfig
for i := range dbConfigReflectValue.NumField() {
dbConfigField := dbConfigReflectValue.Field(i)
currentConfigVar := dbConfigField.Interface().(model.AppConfigVariable)
var storedConfigVar model.AppConfigVariable
err := tx.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error
if err != nil {
return err
}
if common.EnvConfig.UiConfigDisabled {
storedConfigVar.Value = s.getConfigVariableFromEnvironmentVariable(currentConfigVar.Key, storedConfigVar.DefaultValue)
} else if storedConfigVar.Value == "" && storedConfigVar.DefaultValue != "" {
storedConfigVar.Value = storedConfigVar.DefaultValue
}
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
}
return nil
})
}
func (s *AppConfigService) getConfigVariableFromEnvironmentVariable(key, fallbackValue string) string {
environmentVariableName := utils.CamelCaseToScreamingSnakeCase(key)
if value, exists := os.LookupEnv(environmentVariableName); exists {
return value
// If the UI config is disabled, only load from the env
if common.EnvConfig.UiConfigDisabled {
dest, err = s.loadDbConfigFromEnv()
} else {
dest, err = s.loadDbConfigInternal(ctx, s.db)
}
if err != nil {
return err
}
return fallbackValue
// Update the value in the object
s.dbConfig.Store(dest)
return nil
}
func (s *AppConfigService) loadDbConfigFromEnv() (*model.AppConfig, error) {
// First, start from the default configuration
dest := s.getDefaultDbConfig()
// Iterate through each field
rt := reflect.ValueOf(dest).Elem().Type()
rv := reflect.ValueOf(dest).Elem()
for i := range rt.NumField() {
field := rt.Field(i)
// Get the value of the key tag, taking only what's before the comma
// The env var name is the key converted to SCREAMING_SNAKE_CASE
key, _, _ := strings.Cut(field.Tag.Get("key"), ",")
envVarName := utils.CamelCaseToScreamingSnakeCase(key)
// Set the value if it's set
value, ok := os.LookupEnv(envVarName)
if ok {
rv.Field(i).FieldByName("Value").SetString(value)
}
}
return dest, nil
}
func (s *AppConfigService) loadDbConfigInternal(ctx context.Context, tx *gorm.DB) (*model.AppConfig, error) {
// First, start from the default configuration
dest := s.getDefaultDbConfig()
// Load all configuration values from the database
// This loads all values in a single shot
loaded := []model.AppConfigVariable{}
queryCtx, queryCancel := context.WithTimeout(ctx, 10*time.Second)
defer queryCancel()
err := tx.
WithContext(queryCtx).
Find(&loaded).Error
if err != nil {
return nil, fmt.Errorf("failed to load configuration from the database: %w", err)
}
// Iterate through all values loaded from the database
for _, v := range loaded {
// If the value is empty, it means we are using the default value
if v.Value == "" {
continue
}
// Find the field in the struct whose "key" tag matches, then update that
err = dest.UpdateField(v.Key, v.Value, false)
// We ignore the case of fields that don't exist, as there may be leftover data in the database
if err != nil && !errors.Is(err, model.AppConfigKeyNotFoundError{}) {
return nil, fmt.Errorf("failed to process config for key '%s': %w", v.Key, err)
}
}
return dest, nil
}

View File

@@ -0,0 +1,561 @@
package service
import (
"sync/atomic"
"testing"
"time"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/stretchr/testify/require"
)
// NewTestAppConfigService is a function used by tests to create AppConfigService objects with pre-defined configuration values
func NewTestAppConfigService(config *model.AppConfig) *AppConfigService {
service := &AppConfigService{
dbConfig: atomic.Pointer[model.AppConfig]{},
}
service.dbConfig.Store(config)
return service
}
func TestLoadDbConfig(t *testing.T) {
t.Run("empty config table", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
service := &AppConfigService{
db: db,
}
// Load the config
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Config should be equal to default config
require.Equal(t, service.GetDbConfig(), service.getDefaultDbConfig())
})
t.Run("loads value from config table", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Populate the config table with some initial values
err := db.
Create([]model.AppConfigVariable{
// Should be set to the default value because it's an empty string
{Key: "appName", Value: ""},
// Overrides default value
{Key: "sessionDuration", Value: "5"},
// Does not have a default value
{Key: "smtpHost", Value: "example"},
}).
Error
require.NoError(t, err)
// Load the config
service := &AppConfigService{
db: db,
}
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Values should match expected ones
expect := service.getDefaultDbConfig()
expect.SessionDuration.Value = "5"
expect.SmtpHost.Value = "example"
require.Equal(t, service.GetDbConfig(), expect)
})
t.Run("ignores unknown config keys", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Add an entry with a key that doesn't exist in the config struct
err := db.Create([]model.AppConfigVariable{
{Key: "__nonExistentKey", Value: "some value"},
{Key: "appName", Value: "TestApp"}, // This one should still be loaded
}).Error
require.NoError(t, err)
service := &AppConfigService{
db: db,
}
// This should not fail, just ignore the unknown key
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
config := service.GetDbConfig()
require.Equal(t, "TestApp", config.AppName.Value)
})
t.Run("loading config multiple times", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Initial state
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "InitialApp"},
}).Error
require.NoError(t, err)
service := &AppConfigService{
db: db,
}
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
require.Equal(t, "InitialApp", service.GetDbConfig().AppName.Value)
// Update the database value
err = db.Model(&model.AppConfigVariable{}).
Where("key = ?", "appName").
Update("value", "UpdatedApp").Error
require.NoError(t, err)
// Load the config again, it should reflect the updated value
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
require.Equal(t, "UpdatedApp", service.GetDbConfig().AppName.Value)
})
t.Run("loads config from env when UiConfigDisabled is true", func(t *testing.T) {
// Save the original state and restore it after the test
originalUiConfigDisabled := common.EnvConfig.UiConfigDisabled
defer func() {
common.EnvConfig.UiConfigDisabled = originalUiConfigDisabled
}()
// Set environment variables for testing
t.Setenv("APP_NAME", "EnvTest App")
t.Setenv("SESSION_DURATION", "45")
// Enable UiConfigDisabled to load from env
common.EnvConfig.UiConfigDisabled = true
// Create database with config that should be ignored
db := newAppConfigTestDatabaseForTest(t)
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "DB App"},
{Key: "sessionDuration", Value: "120"},
}).Error
require.NoError(t, err)
service := &AppConfigService{
db: db,
}
// Load the config
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Config should be loaded from env, not DB
config := service.GetDbConfig()
require.Equal(t, "EnvTest App", config.AppName.Value, "Should load appName from env")
require.Equal(t, "45", config.SessionDuration.Value, "Should load sessionDuration from env")
})
t.Run("ignores env vars when UiConfigDisabled is false", func(t *testing.T) {
// Save the original state and restore it after the test
originalUiConfigDisabled := common.EnvConfig.UiConfigDisabled
defer func() {
common.EnvConfig.UiConfigDisabled = originalUiConfigDisabled
}()
// Set environment variables that should be ignored
t.Setenv("APP_NAME", "EnvTest App")
t.Setenv("SESSION_DURATION", "45")
// Make sure UiConfigDisabled is false to load from DB
common.EnvConfig.UiConfigDisabled = false
// Create database with config values that should take precedence
db := newAppConfigTestDatabaseForTest(t)
err := db.Create([]model.AppConfigVariable{
{Key: "appName", Value: "DB App"},
{Key: "sessionDuration", Value: "120"},
}).Error
require.NoError(t, err)
service := &AppConfigService{
db: db,
}
// Load the config
err = service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Config should be loaded from DB, not env
config := service.GetDbConfig()
require.Equal(t, "DB App", config.AppName.Value, "Should load appName from DB, not env")
require.Equal(t, "120", config.SessionDuration.Value, "Should load sessionDuration from DB, not env")
})
}
func TestUpdateAppConfigValues(t *testing.T) {
t.Run("update single value", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Update a single config value
err = service.UpdateAppConfigValues(t.Context(), "appName", "Test App")
require.NoError(t, err)
// Verify in-memory config was updated
config := service.GetDbConfig()
require.Equal(t, "Test App", config.AppName.Value)
// Verify database was updated
var dbValue model.AppConfigVariable
err = db.Where("key = ?", "appName").First(&dbValue).Error
require.NoError(t, err)
require.Equal(t, "Test App", dbValue.Value)
})
t.Run("update multiple values", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Update multiple config values
err = service.UpdateAppConfigValues(
t.Context(),
"appName", "Test App",
"sessionDuration", "30",
"smtpHost", "mail.example.com",
)
require.NoError(t, err)
// Verify in-memory config was updated
config := service.GetDbConfig()
require.Equal(t, "Test App", config.AppName.Value)
require.Equal(t, "30", config.SessionDuration.Value)
require.Equal(t, "mail.example.com", config.SmtpHost.Value)
// Verify database was updated
var count int64
db.Model(&model.AppConfigVariable{}).Count(&count)
require.Equal(t, int64(3), count)
var appName, sessionDuration, smtpHost model.AppConfigVariable
err = db.Where("key = ?", "appName").First(&appName).Error
require.NoError(t, err)
require.Equal(t, "Test App", appName.Value)
err = db.Where("key = ?", "sessionDuration").First(&sessionDuration).Error
require.NoError(t, err)
require.Equal(t, "30", sessionDuration.Value)
err = db.Where("key = ?", "smtpHost").First(&smtpHost).Error
require.NoError(t, err)
require.Equal(t, "mail.example.com", smtpHost.Value)
})
t.Run("empty value resets to default", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// First change the value
err = service.UpdateAppConfigValues(t.Context(), "sessionDuration", "30")
require.NoError(t, err)
require.Equal(t, "30", service.GetDbConfig().SessionDuration.Value)
// Now set it to empty which should use default value
err = service.UpdateAppConfigValues(t.Context(), "sessionDuration", "")
require.NoError(t, err)
require.Equal(t, "60", service.GetDbConfig().SessionDuration.Value) // Default value from getDefaultDbConfig
})
t.Run("error with odd number of arguments", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Try to update with odd number of arguments
err = service.UpdateAppConfigValues(t.Context(), "appName", "Test App", "sessionDuration")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid number of arguments")
})
t.Run("error with invalid key", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Try to update with invalid key
err = service.UpdateAppConfigValues(t.Context(), "nonExistentKey", "some value")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid configuration key")
})
}
func TestUpdateAppConfig(t *testing.T) {
t.Run("updates configuration values from DTO", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Create update DTO
input := dto.AppConfigUpdateDto{
AppName: "Updated App Name",
SessionDuration: "120",
SmtpHost: "smtp.example.com",
SmtpPort: "587",
}
// Update config
updatedVars, err := service.UpdateAppConfig(t.Context(), input)
require.NoError(t, err)
// Verify returned updated variables
require.NotEmpty(t, updatedVars)
var foundAppName, foundSessionDuration, foundSmtpHost, foundSmtpPort bool
for _, v := range updatedVars {
switch v.Key {
case "appName":
require.Equal(t, "Updated App Name", v.Value)
foundAppName = true
case "sessionDuration":
require.Equal(t, "120", v.Value)
foundSessionDuration = true
case "smtpHost":
require.Equal(t, "smtp.example.com", v.Value)
foundSmtpHost = true
case "smtpPort":
require.Equal(t, "587", v.Value)
foundSmtpPort = true
}
}
require.True(t, foundAppName)
require.True(t, foundSessionDuration)
require.True(t, foundSmtpHost)
require.True(t, foundSmtpPort)
// Verify in-memory config was updated
config := service.GetDbConfig()
require.Equal(t, "Updated App Name", config.AppName.Value)
require.Equal(t, "120", config.SessionDuration.Value)
require.Equal(t, "smtp.example.com", config.SmtpHost.Value)
require.Equal(t, "587", config.SmtpPort.Value)
// Verify database was updated
var appName, sessionDuration, smtpHost, smtpPort model.AppConfigVariable
err = db.Where("key = ?", "appName").First(&appName).Error
require.NoError(t, err)
require.Equal(t, "Updated App Name", appName.Value)
err = db.Where("key = ?", "sessionDuration").First(&sessionDuration).Error
require.NoError(t, err)
require.Equal(t, "120", sessionDuration.Value)
err = db.Where("key = ?", "smtpHost").First(&smtpHost).Error
require.NoError(t, err)
require.Equal(t, "smtp.example.com", smtpHost.Value)
err = db.Where("key = ?", "smtpPort").First(&smtpPort).Error
require.NoError(t, err)
require.Equal(t, "587", smtpPort.Value)
})
t.Run("empty values reset to defaults", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config and modify some values
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// First set some non-default values
err = service.UpdateAppConfigValues(t.Context(),
"appName", "Custom App",
"sessionDuration", "120",
)
require.NoError(t, err)
// Create update DTO with empty values to reset to defaults
input := dto.AppConfigUpdateDto{
AppName: "", // Should reset to default "Pocket ID"
SessionDuration: "", // Should reset to default "60"
}
// Update config
updatedVars, err := service.UpdateAppConfig(t.Context(), input)
require.NoError(t, err)
// Verify returned updated variables (they should be empty strings in DB)
var foundAppName, foundSessionDuration bool
for _, v := range updatedVars {
switch v.Key {
case "appName":
require.Equal(t, "Pocket ID", v.Value) // Returns the default value
foundAppName = true
case "sessionDuration":
require.Equal(t, "60", v.Value) // Returns the default value
foundSessionDuration = true
}
}
require.True(t, foundAppName)
require.True(t, foundSessionDuration)
// Verify in-memory config was reset to defaults
config := service.GetDbConfig()
require.Equal(t, "Pocket ID", config.AppName.Value) // Default value
require.Equal(t, "60", config.SessionDuration.Value) // Default value
// Verify database was updated with empty values
for _, key := range []string{"appName", "sessionDuration"} {
var loaded model.AppConfigVariable
err = db.Where("key = ?", key).First(&loaded).Error
require.NoErrorf(t, err, "Failed to load DB value for key '%s'", key)
require.Emptyf(t, loaded.Value, "Loaded value for key '%s' is not empty", key)
}
})
t.Run("auto disables EmailOneTimeAccessEnabled when EmailLoginNotificationEnabled is false", func(t *testing.T) {
db := newAppConfigTestDatabaseForTest(t)
// Create a service with default config
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// First enable both settings
err = service.UpdateAppConfigValues(t.Context(),
"emailLoginNotificationEnabled", "true",
"emailOneTimeAccessEnabled", "true",
)
require.NoError(t, err)
// Verify both are enabled
config := service.GetDbConfig()
require.True(t, config.EmailLoginNotificationEnabled.IsTrue())
require.True(t, config.EmailOneTimeAccessEnabled.IsTrue())
// Now disable EmailLoginNotificationEnabled
input := dto.AppConfigUpdateDto{
EmailLoginNotificationEnabled: "false",
// Don't set EmailOneTimeAccessEnabled, it should be auto-disabled
}
// Update config
_, err = service.UpdateAppConfig(t.Context(), input)
require.NoError(t, err)
// Verify EmailOneTimeAccessEnabled was automatically disabled
config = service.GetDbConfig()
require.False(t, config.EmailLoginNotificationEnabled.IsTrue())
require.False(t, config.EmailOneTimeAccessEnabled.IsTrue())
})
t.Run("cannot update when UiConfigDisabled is true", func(t *testing.T) {
// Save the original state and restore it after the test
originalUiConfigDisabled := common.EnvConfig.UiConfigDisabled
defer func() {
common.EnvConfig.UiConfigDisabled = originalUiConfigDisabled
}()
// Disable UI config
common.EnvConfig.UiConfigDisabled = true
db := newAppConfigTestDatabaseForTest(t)
service := &AppConfigService{
db: db,
}
err := service.LoadDbConfig(t.Context())
require.NoError(t, err)
// Try to update config
_, err = service.UpdateAppConfig(t.Context(), dto.AppConfigUpdateDto{
AppName: "Should Not Update",
})
// Should get a UiConfigDisabledError
require.Error(t, err)
var uiConfigDisabledErr *common.UiConfigDisabledError
require.ErrorAs(t, err, &uiConfigDisabledErr)
})
}
// Implements gorm's logger.Writer interface
type testLoggerAdapter struct {
t *testing.T
}
func (l testLoggerAdapter) Printf(format string, args ...any) {
l.t.Logf(format, args...)
}
func newAppConfigTestDatabaseForTest(t *testing.T) *gorm.DB {
t.Helper()
// Get a name for this in-memory database that is specific to the test
dbName := utils.CreateSha256Hash(t.Name())
// Connect to a new in-memory SQL database
db, err := gorm.Open(
sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"),
&gorm.Config{
TranslateError: true,
Logger: logger.New(
testLoggerAdapter{t: t},
logger.Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: false,
ParameterizedQueries: false,
Colorful: false,
},
),
})
require.NoError(t, err, "Failed to connect to test database")
// Create the app_config_variables table
err = db.Exec(`
CREATE TABLE app_config_variables
(
key VARCHAR(100) NOT NULL PRIMARY KEY,
value TEXT NOT NULL
)
`).Error
require.NoError(t, err, "Failed to create test config table")
return db
}

View File

@@ -72,7 +72,7 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddres
}
// If the user hasn't logged in from the same device before and email notifications are enabled, send an email
if s.appConfigService.DbConfig.EmailLoginNotificationEnabled.IsTrue() && count <= 1 {
if s.appConfigService.GetDbConfig().EmailLoginNotificationEnabled.IsTrue() && count <= 1 {
// We use a background context here as this is running in a goroutine
//nolint:contextcheck
go func() {

View File

@@ -300,19 +300,15 @@ func (s *TestService) ResetApplicationImages() error {
return nil
}
func (s *TestService) ResetAppConfig() error {
// Reseed the config variables
if err := s.appConfigService.InitDbConfig(context.Background()); err != nil {
return err
}
// Reset all app config variables to their default values
if err := s.db.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&model.AppConfigVariable{}).Update("value", "").Error; err != nil {
func (s *TestService) ResetAppConfig(ctx context.Context) error {
// Reset all app config variables to their default values in the database
err := s.db.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&model.AppConfigVariable{}).Update("value", "").Error
if err != nil {
return err
}
// Reload the app config from the database after resetting the values
return s.appConfigService.LoadDbConfigFromDb()
return s.appConfigService.LoadDbConfig(ctx)
}
func (s *TestService) SetJWTKeys() {

View File

@@ -70,8 +70,10 @@ func (srv *EmailService) SendTestEmail(ctx context.Context, recipientUserId stri
}
func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Address, template email.Template[V], tData *V) error {
dbConfig := srv.appConfigService.GetDbConfig()
data := &email.TemplateData[V]{
AppName: srv.appConfigService.DbConfig.AppName.Value,
AppName: dbConfig.AppName.Value,
LogoURL: common.EnvConfig.AppURL + "/api/application-configuration/logo",
Data: tData,
}
@@ -86,8 +88,8 @@ func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Addr
c.AddHeader("Subject", template.Title(data))
c.AddAddressHeader("From", []email.Address{
{
Email: srv.appConfigService.DbConfig.SmtpFrom.Value,
Name: srv.appConfigService.DbConfig.AppName.Value,
Email: dbConfig.SmtpFrom.Value,
Name: dbConfig.AppName.Value,
},
})
c.AddAddressHeader("To", []email.Address{toEmail})
@@ -102,7 +104,7 @@ func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Addr
// so we use the domain of the from address instead (the same as Thunderbird does)
// if the address does not have an @ (which would be unusual), we use hostname
from_address := srv.appConfigService.DbConfig.SmtpFrom.Value
from_address := dbConfig.SmtpFrom.Value
domain := ""
if strings.Contains(from_address, "@") {
domain = strings.Split(from_address, "@")[1]
@@ -152,16 +154,18 @@ func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Addr
}
func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
port := srv.appConfigService.DbConfig.SmtpPort.Value
smtpAddress := srv.appConfigService.DbConfig.SmtpHost.Value + ":" + port
dbConfig := srv.appConfigService.GetDbConfig()
port := dbConfig.SmtpPort.Value
smtpAddress := dbConfig.SmtpHost.Value + ":" + port
tlsConfig := &tls.Config{
InsecureSkipVerify: srv.appConfigService.DbConfig.SmtpSkipCertVerify.IsTrue(), //nolint:gosec
ServerName: srv.appConfigService.DbConfig.SmtpHost.Value,
InsecureSkipVerify: dbConfig.SmtpSkipCertVerify.IsTrue(), //nolint:gosec
ServerName: dbConfig.SmtpHost.Value,
}
// Connect to the SMTP server based on TLS setting
switch srv.appConfigService.DbConfig.SmtpTls.Value {
switch dbConfig.SmtpTls.Value {
case "none":
client, err = smtp.Dial(smtpAddress)
case "tls":
@@ -172,7 +176,7 @@ func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
tlsConfig,
)
default:
return nil, fmt.Errorf("invalid SMTP TLS setting: %s", srv.appConfigService.DbConfig.SmtpTls.Value)
return nil, fmt.Errorf("invalid SMTP TLS setting: %s", dbConfig.SmtpTls.Value)
}
if err != nil {
return nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
@@ -186,8 +190,8 @@ func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
}
// Set up the authentication if user or password are set
smtpUser := srv.appConfigService.DbConfig.SmtpUser.Value
smtpPassword := srv.appConfigService.DbConfig.SmtpPassword.Value
smtpUser := dbConfig.SmtpUser.Value
smtpPassword := dbConfig.SmtpPassword.Value
if smtpUser != "" || smtpPassword != "" {
// Authenticate with plain auth
@@ -223,7 +227,7 @@ func (srv *EmailService) sendHelloCommand(client *smtp.Client) error {
func (srv *EmailService) sendEmailContent(client *smtp.Client, toEmail email.Address, c *email.Composer) error {
// Set the sender
if err := client.Mail(srv.appConfigService.DbConfig.SmtpFrom.Value, nil); err != nil {
if err := client.Mail(srv.appConfigService.GetDbConfig().SmtpFrom.Value, nil); err != nil {
return fmt.Errorf("failed to set sender: %w", err)
}

View File

@@ -182,7 +182,7 @@ func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
now := time.Now()
token, err := jwt.NewBuilder().
Subject(user.ID).
Expiration(now.Add(s.appConfigService.DbConfig.SessionDuration.AsDurationMinutes())).
Expiration(now.Add(s.appConfigService.GetDbConfig().SessionDuration.AsDurationMinutes())).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Build()

View File

@@ -26,11 +26,9 @@ import (
)
func TestJwtService_Init(t *testing.T) {
mockConfig := &AppConfigService{
DbConfig: &model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
},
}
mockConfig := NewTestAppConfigService(&model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
})
t.Run("should generate new key when none exists", func(t *testing.T) {
// Create a temporary directory for the test
@@ -142,11 +140,9 @@ func TestJwtService_Init(t *testing.T) {
}
func TestJwtService_GetPublicJWK(t *testing.T) {
mockConfig := &AppConfigService{
DbConfig: &model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
},
}
mockConfig := NewTestAppConfigService(&model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
})
t.Run("returns public key when private key is initialized", func(t *testing.T) {
// Create a temporary directory for the test
@@ -276,11 +272,9 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
tempDir := t.TempDir()
// Initialize the JWT service with a mock AppConfigService
mockConfig := &AppConfigService{
DbConfig: &model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
},
}
mockConfig := NewTestAppConfigService(&model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
})
// Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL
@@ -366,11 +360,9 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
t.Run("uses session duration from config", func(t *testing.T) {
// Create a JWT service with a different session duration
customMockConfig := &AppConfigService{
DbConfig: &model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "30"}, // 30 minutes
},
}
customMockConfig := NewTestAppConfigService(&model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "30"}, // 30 minutes
})
service := &JwtService{}
err := service.init(customMockConfig, tempDir)
@@ -567,11 +559,9 @@ func TestGenerateVerifyIdToken(t *testing.T) {
tempDir := t.TempDir()
// Initialize the JWT service with a mock AppConfigService
mockConfig := &AppConfigService{
DbConfig: &model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
},
}
mockConfig := NewTestAppConfigService(&model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
})
// Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL
@@ -900,11 +890,9 @@ func TestGenerateVerifyOauthAccessToken(t *testing.T) {
tempDir := t.TempDir()
// Initialize the JWT service with a mock AppConfigService
mockConfig := &AppConfigService{
DbConfig: &model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
},
}
mockConfig := NewTestAppConfigService(&model.AppConfig{
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
})
// Setup the environment variable required by the token verification
originalAppURL := common.EnvConfig.AppURL

View File

@@ -32,12 +32,15 @@ func NewLdapService(db *gorm.DB, appConfigService *AppConfigService, userService
}
func (s *LdapService) createClient() (*ldap.Conn, error) {
if !s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
dbConfig := s.appConfigService.GetDbConfig()
if !dbConfig.LdapEnabled.IsTrue() {
return nil, fmt.Errorf("LDAP is not enabled")
}
// Setup LDAP connection
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.IsTrue()
ldapURL := dbConfig.LdapUrl.Value
skipTLSVerify := dbConfig.LdapSkipCertVerify.IsTrue()
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: skipTLSVerify, //nolint:gosec
}))
@@ -46,8 +49,8 @@ func (s *LdapService) createClient() (*ldap.Conn, error) {
}
// Bind as service account
bindDn := s.appConfigService.DbConfig.LdapBindDn.Value
bindPassword := s.appConfigService.DbConfig.LdapBindPassword.Value
bindDn := dbConfig.LdapBindDn.Value
bindPassword := dbConfig.LdapBindPassword.Value
err = client.Bind(bindDn, bindPassword)
if err != nil {
return nil, fmt.Errorf("failed to bind to LDAP: %w", err)
@@ -83,6 +86,8 @@ func (s *LdapService) SyncAll(ctx context.Context) error {
//nolint:gocognit
func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
dbConfig := s.appConfigService.GetDbConfig()
// Setup LDAP connection
client, err := s.createClient()
if err != nil {
@@ -90,19 +95,20 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
}
defer client.Close()
baseDN := s.appConfigService.DbConfig.LdapBase.Value
nameAttribute := s.appConfigService.DbConfig.LdapAttributeGroupName.Value
uniqueIdentifierAttribute := s.appConfigService.DbConfig.LdapAttributeGroupUniqueIdentifier.Value
groupMemberOfAttribute := s.appConfigService.DbConfig.LdapAttributeGroupMember.Value
filter := s.appConfigService.DbConfig.LdapUserGroupSearchFilter.Value
searchAttrs := []string{
nameAttribute,
uniqueIdentifierAttribute,
groupMemberOfAttribute,
dbConfig.LdapAttributeGroupName.Value,
dbConfig.LdapAttributeGroupUniqueIdentifier.Value,
dbConfig.LdapAttributeGroupMember.Value,
}
searchReq := ldap.NewSearchRequest(baseDN, ldap.ScopeWholeSubtree, 0, 0, 0, false, filter, searchAttrs, []ldap.Control{})
searchReq := ldap.NewSearchRequest(
dbConfig.LdapBase.Value,
ldap.ScopeWholeSubtree,
0, 0, 0, false,
dbConfig.LdapUserGroupSearchFilter.Value,
searchAttrs,
[]ldap.Control{},
)
result, err := client.Search(searchReq)
if err != nil {
return fmt.Errorf("failed to query LDAP: %w", err)
@@ -114,11 +120,11 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
for _, value := range result.Entries {
var membersUserId []string
ldapId := value.GetAttributeValue(uniqueIdentifierAttribute)
ldapId := value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value)
// Skip groups without a valid LDAP ID
if ldapId == "" {
log.Printf("Skipping LDAP group without a valid unique identifier (attribute: %s)", uniqueIdentifierAttribute)
log.Printf("Skipping LDAP group without a valid unique identifier (attribute: %s)", dbConfig.LdapAttributeGroupUniqueIdentifier.Value)
continue
}
@@ -129,7 +135,7 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
tx.WithContext(ctx).Where("ldap_id = ?", ldapId).First(&databaseGroup)
// Get group members and add to the correct Group
groupMembers := value.GetAttributeValues(groupMemberOfAttribute)
groupMembers := value.GetAttributeValues(dbConfig.LdapAttributeGroupMember.Value)
for _, member := range groupMembers {
// Normal output of this would be CN=username,ou=people,dc=example,dc=com
// Splitting at the "=" and "," then just grabbing the username for that string
@@ -151,9 +157,9 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
}
syncGroup := dto.UserGroupCreateDto{
Name: value.GetAttributeValue(nameAttribute),
FriendlyName: value.GetAttributeValue(nameAttribute),
LdapID: value.GetAttributeValue(uniqueIdentifierAttribute),
Name: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
FriendlyName: value.GetAttributeValue(dbConfig.LdapAttributeGroupName.Value),
LdapID: value.GetAttributeValue(dbConfig.LdapAttributeGroupUniqueIdentifier.Value),
}
if databaseGroup.ID == "" {
@@ -214,6 +220,8 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB) error {
//nolint:gocognit
func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
dbConfig := s.appConfigService.GetDbConfig()
// Setup LDAP connection
client, err := s.createClient()
if err != nil {
@@ -221,30 +229,27 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
}
defer client.Close()
baseDN := s.appConfigService.DbConfig.LdapBase.Value
uniqueIdentifierAttribute := s.appConfigService.DbConfig.LdapAttributeUserUniqueIdentifier.Value
usernameAttribute := s.appConfigService.DbConfig.LdapAttributeUserUsername.Value
emailAttribute := s.appConfigService.DbConfig.LdapAttributeUserEmail.Value
firstNameAttribute := s.appConfigService.DbConfig.LdapAttributeUserFirstName.Value
lastNameAttribute := s.appConfigService.DbConfig.LdapAttributeUserLastName.Value
profilePictureAttribute := s.appConfigService.DbConfig.LdapAttributeUserProfilePicture.Value
adminGroupAttribute := s.appConfigService.DbConfig.LdapAttributeAdminGroup.Value
filter := s.appConfigService.DbConfig.LdapUserSearchFilter.Value
searchAttrs := []string{
"memberOf",
"sn",
"cn",
uniqueIdentifierAttribute,
usernameAttribute,
emailAttribute,
firstNameAttribute,
lastNameAttribute,
profilePictureAttribute,
dbConfig.LdapAttributeUserUniqueIdentifier.Value,
dbConfig.LdapAttributeUserUsername.Value,
dbConfig.LdapAttributeUserEmail.Value,
dbConfig.LdapAttributeUserFirstName.Value,
dbConfig.LdapAttributeUserLastName.Value,
dbConfig.LdapAttributeUserProfilePicture.Value,
}
// Filters must start and finish with ()!
searchReq := ldap.NewSearchRequest(baseDN, ldap.ScopeWholeSubtree, 0, 0, 0, false, filter, searchAttrs, []ldap.Control{})
searchReq := ldap.NewSearchRequest(
dbConfig.LdapBase.Value,
ldap.ScopeWholeSubtree,
0, 0, 0, false,
dbConfig.LdapUserSearchFilter.Value,
searchAttrs,
[]ldap.Control{},
)
result, err := client.Search(searchReq)
if err != nil {
@@ -255,11 +260,11 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
ldapUserIDs := make(map[string]bool)
for _, value := range result.Entries {
ldapId := value.GetAttributeValue(uniqueIdentifierAttribute)
ldapId := value.GetAttributeValue(dbConfig.LdapAttributeUserUniqueIdentifier.Value)
// Skip users without a valid LDAP ID
if ldapId == "" {
log.Printf("Skipping LDAP user without a valid unique identifier (attribute: %s)", uniqueIdentifierAttribute)
log.Printf("Skipping LDAP user without a valid unique identifier (attribute: %s)", dbConfig.LdapAttributeUserUniqueIdentifier.Value)
continue
}
@@ -272,16 +277,16 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
// Check if user is admin by checking if they are in the admin group
isAdmin := false
for _, group := range value.GetAttributeValues("memberOf") {
if strings.Contains(group, adminGroupAttribute) {
if strings.Contains(group, dbConfig.LdapAttributeAdminGroup.Value) {
isAdmin = true
}
}
newUser := dto.UserCreateDto{
Username: value.GetAttributeValue(usernameAttribute),
Email: value.GetAttributeValue(emailAttribute),
FirstName: value.GetAttributeValue(firstNameAttribute),
LastName: value.GetAttributeValue(lastNameAttribute),
Username: value.GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value),
Email: value.GetAttributeValue(dbConfig.LdapAttributeUserEmail.Value),
FirstName: value.GetAttributeValue(dbConfig.LdapAttributeUserFirstName.Value),
LastName: value.GetAttributeValue(dbConfig.LdapAttributeUserLastName.Value),
IsAdmin: isAdmin,
LdapID: ldapId,
}
@@ -299,7 +304,7 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB) error {
}
// Save profile picture
if pictureString := value.GetAttributeValue(profilePictureAttribute); pictureString != "" {
if pictureString := value.GetAttributeValue(dbConfig.LdapAttributeUserProfilePicture.Value); pictureString != "" {
if err := s.saveProfilePicture(ctx, databaseUser.ID, pictureString); err != nil {
log.Printf("Error saving profile picture for user %s: %v", newUser.Username, err)
}

View File

@@ -748,7 +748,7 @@ func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID
if slices.Contains(scopes, "email") {
claims["email"] = user.Email
claims["email_verified"] = s.appConfigService.DbConfig.EmailsVerified.IsTrue()
claims["email_verified"] = s.appConfigService.GetDbConfig().EmailsVerified.IsTrue()
}
if slices.Contains(scopes, "groups") {

View File

@@ -79,7 +79,7 @@ func (s *UserGroupService) Delete(ctx context.Context, id string) error {
}
// Disallow deleting the group if it is an LDAP group and LDAP is enabled
if group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
if group.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return &common.LdapUserGroupUpdateError{}
}
@@ -148,7 +148,7 @@ func (s *UserGroupService) updateInternal(ctx context.Context, id string, input
}
// Disallow updating the group if it is an LDAP group and LDAP is enabled
if !allowLdapUpdate && group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
if !allowLdapUpdate && group.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return model.UserGroup{}, &common.LdapUserGroupUpdateError{}
}

View File

@@ -188,7 +188,7 @@ func (s *UserService) deleteUserInternal(ctx context.Context, userID string, all
}
// Disallow deleting the user if it is an LDAP user and LDAP is enabled
if !allowLdapDelete && user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
if !allowLdapDelete && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return &common.LdapUserUpdateError{}
}
@@ -278,7 +278,7 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
}
// Disallow updating the user if it is an LDAP group and LDAP is enabled
if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.GetDbConfig().LdapEnabled.IsTrue() {
return model.User{}, &common.LdapUserUpdateError{}
}
@@ -314,7 +314,7 @@ func (s *UserService) RequestOneTimeAccessEmail(ctx context.Context, emailAddres
tx.Rollback()
}()
isDisabled := !s.appConfigService.DbConfig.EmailOneTimeAccessEnabled.IsTrue()
isDisabled := !s.appConfigService.GetDbConfig().EmailOneTimeAccessEnabled.IsTrue()
if isDisabled {
return &common.OneTimeAccessDisabledError{}
}

View File

@@ -26,7 +26,7 @@ type WebAuthnService struct {
func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *AuditLogService, appConfigService *AppConfigService) *WebAuthnService {
webauthnConfig := &webauthn.Config{
RPDisplayName: appConfigService.DbConfig.AppName.Value,
RPDisplayName: appConfigService.GetDbConfig().AppName.Value,
RPID: utils.GetHostnameFromURL(common.EnvConfig.AppURL),
RPOrigins: []string{common.EnvConfig.AppURL},
Timeouts: webauthn.TimeoutsConfig{
@@ -43,7 +43,13 @@ func NewWebAuthnService(db *gorm.DB, jwtService *JwtService, auditLogService *Au
},
}
wa, _ := webauthn.New(webauthnConfig)
return &WebAuthnService{db: db, webAuthn: wa, jwtService: jwtService, auditLogService: auditLogService, appConfigService: appConfigService}
return &WebAuthnService{
db: db,
webAuthn: wa,
jwtService: jwtService,
auditLogService: auditLogService,
appConfigService: appConfigService,
}
}
func (s *WebAuthnService) BeginRegistration(ctx context.Context, userID string) (*model.PublicKeyCredentialCreationOptions, error) {
@@ -314,5 +320,5 @@ func (s *WebAuthnService) UpdateCredential(ctx context.Context, userID, credenti
// updateWebAuthnConfig updates the WebAuthn configuration with the app name as it can change during runtime
func (s *WebAuthnService) updateWebAuthnConfig() {
s.webAuthn.Config.RPDisplayName = s.appConfigService.DbConfig.AppName.Value
s.webAuthn.Config.RPDisplayName = s.appConfigService.GetDbConfig().AppName.Value
}