diff --git a/management/server/account.go b/management/server/account.go index e6e77a58b..a771050c9 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1474,7 +1474,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims // if domain already has a primary account, add regular user if domainAcc != nil { account = domainAcc - account.Users[claims.UserId] = NewRegularUser(claims.UserId) + account.Users[claims.UserId] = NewRegularUser(claims.UserId, account.Id) err = am.Store.SaveAccount(account) if err != nil { return nil, err @@ -1863,9 +1863,10 @@ func (am *DefaultAccountManager) onPeersInvalidated(accountID string) { func addAllGroup(account *Account) error { if len(account.Groups) == 0 { allGroup := &nbgroup.Group{ - ID: xid.New().String(), - Name: "All", - Issued: nbgroup.GroupIssuedAPI, + ID: xid.New().String(), + Name: "All", + Issued: nbgroup.GroupIssuedAPI, + AccountID: account.Id, } for _, peer := range account.Peers { allGroup.Peers = append(allGroup.Peers, peer.ID) @@ -1909,7 +1910,7 @@ func newAccountWithId(accountID, userID, domain string) *Account { routes := make(map[string]*route.Route) setupKeys := map[string]*SetupKey{} nameServersGroups := make(map[string]*nbdns.NameServerGroup) - users[userID] = NewOwnerUser(userID) + users[userID] = NewOwnerUser(userID, accountID) dnsSettings := DNSSettings{ DisabledManagementGroups: make([]string, 0), } diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go index 9d174d0be..1c4d6518b 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/accounts_handler_test.go @@ -54,7 +54,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts func TestAccounts_AccountsHandler(t *testing.T) { accountID := "test_account" - adminUser := server.NewAdminUser("test_user") + adminUser := server.NewAdminUser("test_user", "account_id") sr := func(v string) *string { return &v } br := func(v bool) *bool { return &v } diff --git a/management/server/http/dns_settings_handler_test.go b/management/server/http/dns_settings_handler_test.go index a2f65a521..1018bb080 100644 --- a/management/server/http/dns_settings_handler_test.go +++ b/management/server/http/dns_settings_handler_test.go @@ -34,7 +34,7 @@ var testingDNSSettingsAccount = &server.Account{ Id: testDNSSettingsAccountID, Domain: "hotmail.com", Users: map[string]*server.User{ - testDNSSettingsUserID: server.NewAdminUser("test_user"), + testDNSSettingsUserID: server.NewAdminUser("test_user", "account_id"), }, DNSSettings: baseExistingDNSSettings, } diff --git a/management/server/http/events_handler_test.go b/management/server/http/events_handler_test.go index 4cfad922b..46fe2989f 100644 --- a/management/server/http/events_handler_test.go +++ b/management/server/http/events_handler_test.go @@ -196,7 +196,7 @@ func TestEvents_GetEvents(t *testing.T) { }, } accountID := "test_account" - adminUser := server.NewAdminUser("test_user") + adminUser := server.NewAdminUser("test_user", "account_id") events := generateEvents(accountID, adminUser.Id) handler := initEventsTestData(accountID, adminUser, events...) diff --git a/management/server/http/geolocation_handler_test.go b/management/server/http/geolocation_handler_test.go index 226711002..37406ec24 100644 --- a/management/server/http/geolocation_handler_test.go +++ b/management/server/http/geolocation_handler_test.go @@ -42,7 +42,7 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler { return &GeolocationsHandler{ accountManager: &mock_server.MockAccountManager{ GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") + user := server.NewAdminUser("test_user", "account_id") return &server.Account{ Id: claims.AccountId, Users: map[string]*server.User{ diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index 3d74b848c..88ba95931 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -124,7 +124,7 @@ func TestGetGroup(t *testing.T) { Name: "Group", } - adminUser := server.NewAdminUser("test_user") + adminUser := server.NewAdminUser("test_user", "account_id") p := initGroupTestData(adminUser, group) for _, tc := range tt { @@ -246,7 +246,7 @@ func TestWriteGroup(t *testing.T) { }, } - adminUser := server.NewAdminUser("test_user") + adminUser := server.NewAdminUser("test_user", "account_id") p := initGroupTestData(adminUser) for _, tc := range tt { @@ -324,7 +324,7 @@ func TestDeleteGroup(t *testing.T) { }, } - adminUser := server.NewAdminUser("test_user") + adminUser := server.NewAdminUser("test_user", "account_id") p := initGroupTestData(adminUser) for _, tc := range tt { diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go index e1fabb198..2797a192b 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/nameservers_handler_test.go @@ -32,7 +32,7 @@ var testingNSAccount = &server.Account{ Id: testNSGroupAccountID, Domain: "hotmail.com", Users: map[string]*server.User{ - "test_user": server.NewAdminUser("test_user"), + "test_user": server.NewAdminUser("test_user", "account_id"), }, } diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index e43c4375e..1f3909f55 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -59,7 +59,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { return "netbird.selfhosted" }, GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") + user := server.NewAdminUser("test_user", "account_id") return &server.Account{ Id: claims.AccountId, Domain: "hotmail.com", diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index 74e682854..904ca25fa 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -45,7 +45,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { return nil }, GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") + user := server.NewAdminUser("test_user", "account_id") return &server.Account{ Id: claims.AccountId, Domain: "hotmail.com", diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go index 733fdf7d2..d5477078b 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/posture_checks_handler_test.go @@ -67,7 +67,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH return accountPostureChecks, nil }, GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") + user := server.NewAdminUser("test_user", "account_id") return &server.Account{ Id: claims.AccountId, Users: map[string]*server.User{ diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go index c02292f2a..81bdf7d0e 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/routes_handler_test.go @@ -75,7 +75,7 @@ var testingAccount = &server.Account{ }, }, Users: map[string]*server.User{ - "test_user": server.NewAdminUser("test_user"), + "test_user": server.NewAdminUser("test_user", "account_id"), }, } diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index ebbd5954f..e42b0b669 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -97,7 +97,7 @@ func TestSetupKeysHandlers(t *testing.T) { defaultSetupKey := server.GenerateDefaultSetupKey() defaultSetupKey.Id = existingSetupKeyID - adminUser := server.NewAdminUser("test_user") + adminUser := server.NewAdminUser("test_user", "account_id") newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}, server.SetupKeyUnlimitedUsage, true) diff --git a/management/server/scheduler.go b/management/server/scheduler.go index 356348056..224441d9e 100644 --- a/management/server/scheduler.go +++ b/management/server/scheduler.go @@ -95,18 +95,18 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne case <-ticker.C: select { case <-cancel: - log.Debugf("scheduled job %s was canceled, stop timer", ID) + log.Tracef("scheduled job %s was canceled, stop timer", ID) ticker.Stop() return default: - log.Debugf("time to do a scheduled job %s", ID) + log.Tracef("time to do a scheduled job %s", ID) } runIn, reschedule := job() if !reschedule { wm.mu.Lock() defer wm.mu.Unlock() delete(wm.jobs, ID) - log.Debugf("job %s is not scheduled to run again", ID) + log.Tracef("job %s is not scheduled to run again", ID) ticker.Stop() return } @@ -115,7 +115,7 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne ticker.Reset(runIn) } case <-cancel: - log.Debugf("job %s was canceled, stopping timer", ID) + log.Tracef("job %s was canceled, stopping timer", ID) ticker.Stop() return } diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go index e6a9c8467..5b4fc92d7 100644 --- a/management/server/sqlite_store.go +++ b/management/server/sqlite_store.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "path/filepath" + "reflect" "runtime" "strings" "sync" @@ -134,72 +135,139 @@ func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) { return unlock } +func batchInsert(records interface{}, batchSize int, tx *gorm.DB) error { + // Get the reflect.Value of the records slice + v := reflect.ValueOf(records) + if v.Kind() != reflect.Slice { + return fmt.Errorf("provided input is not a slice") + } + + // Insert records in batches + for i := 0; i < v.Len(); i += batchSize { + end := i + batchSize + if end > v.Len() { + end = v.Len() + } + // Use reflect.Slice to get a slice of the records for the current batch + batch := v.Slice(i, end).Interface() + if err := tx.CreateInBatches(batch, end-i).Debug().Error; err != nil { + return err + } + } + return nil +} + func (s *SqliteStore) SaveAccount(account *Account) error { start := time.Now() - for _, key := range account.SetupKeys { - account.SetupKeysG = append(account.SetupKeysG, *key) + // operate over a fresh copy as we will modify its fields + accCopy := account.Copy() + accCopy.SetupKeysG = make([]SetupKey, 0, len(accCopy.SetupKeys)) + for _, key := range accCopy.SetupKeys { + //we need an explicit reference to the account for gorm + key.AccountID = accCopy.Id + accCopy.SetupKeysG = append(accCopy.SetupKeysG, *key) } - for id, peer := range account.Peers { + accCopy.PeersG = make([]nbpeer.Peer, 0, len(accCopy.Peers)) + for id, peer := range accCopy.Peers { peer.ID = id - account.PeersG = append(account.PeersG, *peer) + //we need an explicit reference to the account for gorm + peer.AccountID = accCopy.Id + accCopy.PeersG = append(accCopy.PeersG, *peer) } - for id, user := range account.Users { + accCopy.UsersG = make([]User, 0, len(accCopy.Users)) + for id, user := range accCopy.Users { user.Id = id + //we need an explicit reference to the account for gorm + user.AccountID = accCopy.Id + user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs)) for id, pat := range user.PATs { pat.ID = id user.PATsG = append(user.PATsG, *pat) } - account.UsersG = append(account.UsersG, *user) + accCopy.UsersG = append(accCopy.UsersG, *user) } - for id, group := range account.Groups { + accCopy.GroupsG = make([]nbgroup.Group, 0, len(accCopy.Groups)) + for id, group := range accCopy.Groups { group.ID = id - account.GroupsG = append(account.GroupsG, *group) + //we need an explicit reference to the account for gorm + group.AccountID = accCopy.Id + accCopy.GroupsG = append(accCopy.GroupsG, *group) } - for id, route := range account.Routes { + accCopy.RoutesG = make([]route.Route, 0, len(accCopy.Routes)) + for id, route := range accCopy.Routes { route.ID = id - account.RoutesG = append(account.RoutesG, *route) + //we need an explicit reference to the account for gorm + route.AccountID = accCopy.Id + accCopy.RoutesG = append(accCopy.RoutesG, *route) } - for id, ns := range account.NameServerGroups { + accCopy.NameServerGroupsG = make([]nbdns.NameServerGroup, 0, len(accCopy.NameServerGroups)) + for id, ns := range accCopy.NameServerGroups { ns.ID = id - account.NameServerGroupsG = append(account.NameServerGroupsG, *ns) + //we need an explicit reference to the account for gorm + ns.AccountID = accCopy.Id + accCopy.NameServerGroupsG = append(accCopy.NameServerGroupsG, *ns) } err := s.db.Transaction(func(tx *gorm.DB) error { - result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) + result := tx.Select(clause.Associations).Delete(accCopy.Policies, "account_id = ?", accCopy.Id) if result.Error != nil { return result.Error } - result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id) + result = tx.Select(clause.Associations).Delete(accCopy.UsersG, "account_id = ?", accCopy.Id) if result.Error != nil { return result.Error } - result = tx.Select(clause.Associations).Delete(account) + result = tx.Select(clause.Associations).Delete(accCopy) if result.Error != nil { return result.Error } result = tx. Session(&gorm.Session{FullSaveAssociations: true}). - Clauses(clause.OnConflict{UpdateAll: true}).Create(account) + Clauses(clause.OnConflict{UpdateAll: true}). + Omit("PeersG", "GroupsG", "UsersG", "SetupKeysG", "RoutesG", "NameServerGroupsG"). + Create(accCopy) if result.Error != nil { return result.Error } - return nil + + const batchSize = 500 + err := batchInsert(accCopy.PeersG, batchSize, tx) + if err != nil { + return err + } + err = batchInsert(accCopy.UsersG, batchSize, tx) + if err != nil { + return err + } + err = batchInsert(accCopy.GroupsG, batchSize, tx) + if err != nil { + return err + } + err = batchInsert(accCopy.RoutesG, batchSize, tx) + if err != nil { + return err + } + err = batchInsert(accCopy.SetupKeysG, batchSize, tx) + if err != nil { + return err + } + return batchInsert(accCopy.NameServerGroupsG, batchSize, tx) }) took := time.Since(start) if s.metrics != nil { s.metrics.StoreMetrics().CountPersistenceDuration(took) } - log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds()) + log.Debugf("took %d ms to persist an account %s to the SQLite store", took.Milliseconds(), accCopy.Id) return err } @@ -207,6 +275,19 @@ func (s *SqliteStore) SaveAccount(account *Account) error { func (s *SqliteStore) DeleteAccount(account *Account) error { start := time.Now() + account.UsersG = make([]User, 0, len(account.Users)) + for id, user := range account.Users { + user.Id = id + //we need an explicit reference to an account as it is missing for some reason + user.AccountID = account.Id + user.PATsG = make([]PersonalAccessToken, 0, len(user.PATs)) + for id, pat := range user.PATs { + pat.ID = id + user.PATsG = append(user.PATsG, *pat) + } + account.UsersG = append(account.UsersG, *user) + } + err := s.db.Transaction(func(tx *gorm.DB) error { result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) if result.Error != nil { diff --git a/management/server/sqlite_store_test.go b/management/server/sqlite_store_test.go index e43a0cd9a..88ab54913 100644 --- a/management/server/sqlite_store_test.go +++ b/management/server/sqlite_store_test.go @@ -2,7 +2,12 @@ package server import ( "fmt" + nbdns "github.com/netbirdio/netbird/dns" + nbgroup "github.com/netbirdio/netbird/management/server/group" + route2 "github.com/netbirdio/netbird/route" + "math/rand" "net" + "net/netip" "path/filepath" "runtime" "testing" @@ -29,6 +34,141 @@ func TestSqlite_NewStore(t *testing.T) { t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") } } +func TestSqlite_SaveAccount_Large(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStore(t) + + account := newAccountWithId("account_id", "testuser", "") + groupALL, err := account.GetGroupAll() + if err != nil { + t.Fatal(err) + } + setupKey := GenerateDefaultSetupKey() + account.SetupKeys[setupKey.Key] = setupKey + const numPerAccount = 2000 + for n := 0; n < numPerAccount; n++ { + netIP := randomIPv4() + peerID := fmt.Sprintf("%s-peer-%d", account.Id, n) + + peer := &nbpeer.Peer{ + ID: peerID, + Key: peerID, + SetupKey: "", + IP: netIP, + Name: peerID, + DNSLabel: peerID, + UserID: userID, + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, + SSHEnabled: false, + } + account.Peers[peerID] = peer + group, _ := account.GetGroupAll() + group.Peers = append(group.Peers, peerID) + user := &User{ + Id: fmt.Sprintf("%s-user-%d", account.Id, n), + AccountID: account.Id, + } + account.Users[user.Id] = user + route := &route2.Route{ + ID: fmt.Sprintf("network-id-%d", n), + Description: "base route", + NetID: fmt.Sprintf("network-id-%d", n), + Network: netip.MustParsePrefix(netIP.String() + "/24"), + NetworkType: route2.IPv4Network, + Metric: 9999, + Masquerade: false, + Enabled: true, + Groups: []string{groupALL.ID}, + } + account.Routes[route.ID] = route + + group = &nbgroup.Group{ + ID: fmt.Sprintf("group-id-%d", n), + AccountID: account.Id, + Name: fmt.Sprintf("group-id-%d", n), + Issued: "api", + Peers: nil, + } + account.Groups[group.ID] = group + + nameserver := &nbdns.NameServerGroup{ + ID: fmt.Sprintf("nameserver-id-%d", n), + AccountID: account.Id, + Name: fmt.Sprintf("nameserver-id-%d", n), + Description: "", + NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}}, + Groups: []string{group.ID}, + Primary: false, + Domains: nil, + Enabled: false, + SearchDomainsEnabled: false, + } + account.NameServerGroups[nameserver.ID] = nameserver + + setupKey := GenerateDefaultSetupKey() + account.SetupKeys[setupKey.Key] = setupKey + } + + err = store.SaveAccount(account) + require.NoError(t, err) + + if len(store.GetAllAccounts()) != 1 { + t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") + } + + a, err := store.GetAccount(account.Id) + if a == nil { + t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) + } + + if a != nil && len(a.Policies) != 1 { + t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies)) + } + + if a != nil && len(a.Policies[0].Rules) != 1 { + t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules)) + return + } + + if a != nil && len(a.Peers) != numPerAccount { + t.Errorf("expecting Account to have %d peers stored after SaveAccount(), got %d", + numPerAccount, len(a.Peers)) + return + } + + if a != nil && len(a.Users) != numPerAccount+1 { + t.Errorf("expecting Account to have %d users stored after SaveAccount(), got %d", + numPerAccount+1, len(a.Users)) + return + } + + if a != nil && len(a.Routes) != numPerAccount { + t.Errorf("expecting Account to have %d routes stored after SaveAccount(), got %d", + numPerAccount, len(a.Routes)) + return + } + + if a != nil && len(a.NameServerGroups) != numPerAccount { + t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d", + numPerAccount, len(a.NameServerGroups)) + return + } + + if a != nil && len(a.NameServerGroups) != numPerAccount { + t.Errorf("expecting Account to have %d NameServerGroups stored after SaveAccount(), got %d", + numPerAccount, len(a.NameServerGroups)) + return + } + + if a != nil && len(a.SetupKeys) != numPerAccount+1 { + t.Errorf("expecting Account to have %d SetupKeys stored after SaveAccount(), got %d", + numPerAccount+1, len(a.SetupKeys)) + return + } +} func TestSqlite_SaveAccount(t *testing.T) { if runtime.GOOS == "windows" { @@ -48,6 +188,12 @@ func TestSqlite_SaveAccount(t *testing.T) { Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } + admin := account.Users["testuser"] + admin.PATs = map[string]*PersonalAccessToken{"testtoken": { + ID: "testtoken", + Name: "test token", + HashedToken: "hashed token", + }} err := store.SaveAccount(account) require.NoError(t, err) @@ -110,7 +256,7 @@ func TestSqlite_DeleteAccount(t *testing.T) { store := newSqliteStore(t) testUserID := "testuser" - user := NewAdminUser(testUserID) + user := NewAdminUser(testUserID, "account_id") user.PATs = map[string]*PersonalAccessToken{"testtoken": { ID: "testtoken", Name: "test token", @@ -393,3 +539,12 @@ func newAccount(store Store, id int) error { return store.SaveAccount(account) } + +func randomIPv4() net.IP { + rand.New(rand.NewSource(time.Now().UnixNano())) + b := make([]byte, 4) + for i := range b { + b[i] = byte(rand.Intn(256)) + } + return net.IP(b) +} diff --git a/management/server/user.go b/management/server/user.go index b955c4058..9d3055262 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -180,9 +180,11 @@ func (u *User) Copy() *User { } // NewUser creates a new user -func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User { +func NewUser(ID string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string, + accountID string) *User { return &User{ - Id: id, + Id: ID, + AccountID: accountID, Role: role, IsServiceUser: isServiceUser, NonDeletable: nonDeletable, @@ -194,22 +196,26 @@ func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, se } // NewRegularUser creates a new user with role UserRoleUser -func NewRegularUser(id string) *User { - return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI) +func NewRegularUser(ID, accountID string) *User { + return NewUser(ID, UserRoleUser, false, false, "", []string{}, UserIssuedAPI, + accountID) } // NewAdminUser creates a new user with role UserRoleAdmin -func NewAdminUser(id string) *User { - return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI) +func NewAdminUser(ID, accountID string) *User { + return NewUser(ID, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI, + accountID) } // NewOwnerUser creates a new user with role UserRoleOwner -func NewOwnerUser(id string) *User { - return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI) +func NewOwnerUser(ID, accountID string) *User { + return NewUser(ID, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI, + accountID) } // createServiceUser creates a new service user under the given account. -func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) { +func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, + serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -231,7 +237,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs } newUserID := uuid.New().String() - newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI) + newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI, accountID) log.Debugf("New User: %v", newUser) account.Users[newUserID] = newUser diff --git a/management/server/user_test.go b/management/server/user_test.go index c92f87e6c..82869ad09 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -679,8 +679,8 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { func TestDefaultAccountManager_ListUsers(t *testing.T) { store := newStore(t) account := newAccountWithId(mockAccountID, mockUserID, "") - account.Users["normal_user1"] = NewRegularUser("normal_user1") - account.Users["normal_user2"] = NewRegularUser("normal_user2") + account.Users["normal_user1"] = NewRegularUser("normal_user1", mockAccountID) + account.Users["normal_user2"] = NewRegularUser("normal_user2", mockAccountID) err := store.SaveAccount(account) if err != nil { @@ -760,7 +760,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { store := newStore(t) account := newAccountWithId(mockAccountID, mockUserID, "") - account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI) + account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI, mockAccountID) account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings delete(account.Users, mockUserID) @@ -844,10 +844,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { func TestUser_IsAdmin(t *testing.T) { - user := NewAdminUser(mockUserID) + user := NewAdminUser(mockUserID, mockAccountID) assert.True(t, user.HasAdminPower()) - user = NewRegularUser(mockUserID) + user = NewRegularUser(mockUserID, mockAccountID) assert.False(t, user.HasAdminPower()) } @@ -1055,8 +1055,8 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { } // create other users - account.Users[regularUserID] = NewRegularUser(regularUserID) - account.Users[adminUserID] = NewAdminUser(adminUserID) + account.Users[regularUserID] = NewRegularUser(regularUserID, account.Id) + account.Users[adminUserID] = NewAdminUser(adminUserID, account.Id) account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"} err = manager.Store.SaveAccount(account) if err != nil {