initial implementation

This commit is contained in:
Ashley Mensah
2026-03-02 14:20:50 +01:00
parent 721aa41361
commit cc15f5cb03
11 changed files with 1346 additions and 31 deletions

View File

@@ -3445,6 +3445,80 @@ func (s *SqlStore) GetDB() *gorm.DB {
return s.db
}
// ListUsers returns all users across all accounts with decrypted sensitive fields.
func (s *SqlStore) ListUsers(ctx context.Context) ([]*types.User, error) {
var users []*types.User
if err := s.db.Find(&users).Error; err != nil {
return nil, status.Errorf(status.Internal, "failed to list users")
}
for _, user := range users {
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
log.WithContext(ctx).Errorf("failed to decrypt user data for user %s: %v", user.Id, err)
return nil, status.Errorf(status.Internal, "failed to decrypt user data")
}
}
return users, nil
}
// txDeferFKConstraints defers foreign key constraint checks for the duration of the transaction.
// MySQL is already handled by s.transaction (SET FOREIGN_KEY_CHECKS = 0).
func (s *SqlStore) txDeferFKConstraints(tx *gorm.DB) error {
switch s.storeEngine {
case types.PostgresStoreEngine:
return tx.Exec("SET CONSTRAINTS ALL DEFERRED").Error
case types.SqliteStoreEngine:
return tx.Exec("PRAGMA defer_foreign_keys = ON").Error
default:
return nil
}
}
// UpdateUserID re-keys a user's ID from oldUserID to newUserID, updating all FK references first,
// then the users.id primary key last. All updates happen in a single transaction.
func (s *SqlStore) UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error {
type fkUpdate struct {
model any
column string
where string
}
updates := []fkUpdate{
{&types.PersonalAccessToken{}, "user_id", "user_id = ?"},
{&types.PersonalAccessToken{}, "created_by", "created_by = ?"},
{&nbpeer.Peer{}, "user_id", "user_id = ?"},
{&types.UserInviteRecord{}, "created_by", "created_by = ?"},
{&types.Account{}, "created_by", "created_by = ?"},
{&types.ProxyAccessToken{}, "created_by", "created_by = ?"},
{&types.Job{}, "triggered_by", "triggered_by = ?"},
{&types.PolicyRule{}, "authorized_user", "authorized_user = ?"},
{&accesslogs.AccessLogEntry{}, "user_id", "user_id = ?"},
}
err := s.transaction(func(tx *gorm.DB) error {
if err := s.txDeferFKConstraints(tx); err != nil {
return err
}
for _, u := range updates {
if err := tx.Model(u.model).Where(u.where, oldUserID).Update(u.column, newUserID).Error; err != nil {
return fmt.Errorf("update %s: %w", u.column, err)
}
}
if err := tx.Model(&types.User{}).Where(accountAndIDQueryCondition, accountID, oldUserID).Update("id", newUserID).Error; err != nil {
return fmt.Errorf("update users: %w", err)
}
return nil
})
if err != nil {
log.WithContext(ctx).Errorf("failed to update user ID in the store: %s", err)
return status.Errorf(status.Internal, "failed to update user ID in store")
}
return nil
}
// SetFieldEncrypt sets the field encryptor for encrypting sensitive user data.
func (s *SqlStore) SetFieldEncrypt(enc *crypt.FieldEncrypt) {
s.fieldEncrypt = enc

View File

@@ -275,6 +275,11 @@ type Store interface {
// GetCustomDomainsCounts returns the total and validated custom domain counts.
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
// ListUsers returns all users across all accounts.
ListUsers(ctx context.Context) ([]*types.User, error)
// UpdateUserID re-keys a user's ID from oldUserID to newUserID, updating all foreign key references.
UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error
}
const (

View File

@@ -1109,21 +1109,6 @@ func (mr *MockStoreMockRecorder) GetAccountServices(ctx, lockStrength, accountID
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockStore)(nil).GetAccountServices), ctx, lockStrength, accountID)
}
// GetServicesByAccountID mocks base method.
func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID)
ret0, _ := ret[0].([]*reverseproxy.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServicesByAccountID indicates an expected call of GetServicesByAccountID.
func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID)
}
// GetAccountSettings mocks base method.
func (m *MockStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types2.Settings, error) {
m.ctrl.T.Helper()
@@ -1288,6 +1273,22 @@ func (mr *MockStoreMockRecorder) GetCustomDomain(ctx, accountID, domainID interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomain", reflect.TypeOf((*MockStore)(nil).GetCustomDomain), ctx, accountID, domainID)
}
// GetCustomDomainsCounts mocks base method.
func (m *MockStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetCustomDomainsCounts", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(int64)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetCustomDomainsCounts indicates an expected call of GetCustomDomainsCounts.
func (mr *MockStoreMockRecorder) GetCustomDomainsCounts(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomainsCounts", reflect.TypeOf((*MockStore)(nil).GetCustomDomainsCounts), ctx)
}
// GetDNSRecordByID mocks base method.
func (m *MockStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) {
m.ctrl.T.Helper()
@@ -1872,22 +1873,6 @@ func (mr *MockStoreMockRecorder) GetServiceTargetByTargetID(ctx, lockStrength, a
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceTargetByTargetID", reflect.TypeOf((*MockStore)(nil).GetServiceTargetByTargetID), ctx, lockStrength, accountID, targetID)
}
// GetCustomDomainsCounts mocks base method.
func (m *MockStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetCustomDomainsCounts", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(int64)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetCustomDomainsCounts indicates an expected call of GetCustomDomainsCounts.
func (mr *MockStoreMockRecorder) GetCustomDomainsCounts(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomainsCounts", reflect.TypeOf((*MockStore)(nil).GetCustomDomainsCounts), ctx)
}
// GetServices mocks base method.
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
m.ctrl.T.Helper()
@@ -1903,6 +1888,21 @@ func (mr *MockStoreMockRecorder) GetServices(ctx, lockStrength interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServices", reflect.TypeOf((*MockStore)(nil).GetServices), ctx, lockStrength)
}
// GetServicesByAccountID mocks base method.
func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID)
ret0, _ := ret[0].([]*reverseproxy.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServicesByAccountID indicates an expected call of GetServicesByAccountID.
func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID)
}
// GetSetupKeyByID mocks base method.
func (m *MockStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types2.SetupKey, error) {
m.ctrl.T.Helper()
@@ -2231,6 +2231,21 @@ func (mr *MockStoreMockRecorder) ListFreeDomains(ctx, accountID interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListFreeDomains", reflect.TypeOf((*MockStore)(nil).ListFreeDomains), ctx, accountID)
}
// ListUsers mocks base method.
func (m *MockStore) ListUsers(ctx context.Context) ([]*types2.User, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListUsers", ctx)
ret0, _ := ret[0].([]*types2.User)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListUsers indicates an expected call of ListUsers.
func (mr *MockStoreMockRecorder) ListUsers(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListUsers", reflect.TypeOf((*MockStore)(nil).ListUsers), ctx)
}
// MarkAccountPrimary mocks base method.
func (m *MockStore) MarkAccountPrimary(ctx context.Context, accountID string) error {
m.ctrl.T.Helper()
@@ -2776,6 +2791,20 @@ func (mr *MockStoreMockRecorder) UpdateService(ctx, service interface{}) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockStore)(nil).UpdateService), ctx, service)
}
// UpdateUserID mocks base method.
func (m *MockStore) UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateUserID", ctx, accountID, oldUserID, newUserID)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateUserID indicates an expected call of UpdateUserID.
func (mr *MockStoreMockRecorder) UpdateUserID(ctx, accountID, oldUserID, newUserID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserID", reflect.TypeOf((*MockStore)(nil).UpdateUserID), ctx, accountID, oldUserID, newUserID)
}
// UpdateZone mocks base method.
func (m *MockStore) UpdateZone(ctx context.Context, zone *zones.Zone) error {
m.ctrl.T.Helper()