diff --git a/management/server/account.go b/management/server/account.go index ef58f6545..7ce22f821 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -66,7 +66,7 @@ func cacheEntryExpiration() time.Duration { } type AccountManager interface { - GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error) + GetOrCreateAccountIDByUser(ctx context.Context, userId, domain string) (string, error) GetAccount(ctx context.Context, accountID string) (*Account, error) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) @@ -1267,26 +1267,39 @@ func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx co // newAccount creates a new Account with a generated ID and generated default setup keys. // If ID is already in use (due to collision) we try one more time before returning error -func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) { - for i := 0; i < 2; i++ { - accountId := xid.New().String() +func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (string, error) { + var accountID string + + err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + for i := 0; i < 2; i++ { + accountID = xid.New().String() + + exists, err := transaction.AccountExists(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to check account existence: %v", err) + return err + } + + if !exists { + if err = newAccountWithId(ctx, transaction, accountID, userID, domain); err != nil { + log.WithContext(ctx).Errorf("failed to create new account: %v", err) + return err + } + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountCreated, nil) + + return nil + } - _, err := am.Store.GetAccount(ctx, accountId) - statusErr, _ := status.FromError(err) - switch { - case err == nil: log.WithContext(ctx).Warnf("an account with ID already exists, retrying...") - continue - case statusErr.Type() == status.NotFound: - newAccount := newAccountWithId(ctx, accountId, userID, domain) - am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil) - return newAccount, nil - default: - return nil, err } + + return nil + }) + if err != nil { + return "", status.Errorf(status.Internal, "failed to create new account") } - return nil, status.Errorf(status.Internal, "error while creating new account") + return accountID, nil } func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error { @@ -1400,15 +1413,15 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) + accountID, err = am.GetOrCreateAccountIDByUser(ctx, userID, domain) if err != nil { return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) } - if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil { + if err = am.addAccountIDToIDPAppMeta(ctx, userID, accountID); err != nil { return "", err } - return account.Id, nil + return accountID, nil } return "", err } @@ -1709,7 +1722,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx newCategoty = claims.DomainCategory } - return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain) + return am.Store.UpdateAccountDomainAttributes(ctx, LockingStrengthUpdate, accountID, newDomain, newCategoty, &primaryDomain) } // handleExistingUserAccount handles existing User accounts and update its domain attributes. @@ -1747,29 +1760,26 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai } lowerDomain := strings.ToLower(claims.Domain) + isPrimaryDomain := true - newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain) + newAccountID, err := am.newAccount(ctx, claims.UserId, lowerDomain) if err != nil { return "", err } - newAccount.Domain = lowerDomain - newAccount.DomainCategory = claims.DomainCategory - newAccount.IsDomainPrimaryAccount = true - - err = am.Store.SaveAccount(ctx, newAccount) + err = am.Store.UpdateAccountDomainAttributes(ctx, LockingStrengthUpdate, newAccountID, lowerDomain, claims.DomainCategory, &isPrimaryDomain) if err != nil { return "", err } - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id) + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccountID) if err != nil { return "", err } - am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil) + am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccountID, activity.UserJoined, nil) - return newAccount.Id, nil + return newAccountID, nil } func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { @@ -2419,93 +2429,82 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) } -// addAllGroup to account object if it doesn't exist -func addAllGroup(account *Account) error { - if len(account.Groups) == 0 { - allGroup := &nbgroup.Group{ - ID: xid.New().String(), - Name: "All", - Issued: nbgroup.GroupIssuedAPI, - } - for _, peer := range account.Peers { - allGroup.Peers = append(allGroup.Peers, peer.ID) - } - account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup} - - id := xid.New().String() - - defaultPolicy := &Policy{ - ID: id, - Name: DefaultRuleName, - Description: DefaultRuleDescription, - Enabled: true, - Rules: []*PolicyRule{ - { - ID: id, - Name: DefaultRuleName, - Description: DefaultRuleDescription, - Enabled: true, - Sources: []string{allGroup.ID}, - Destinations: []string{allGroup.ID}, - Bidirectional: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, - }, - }, - } - - account.Policies = []*Policy{defaultPolicy} - } - return nil -} - -// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id -func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Account { +// newAccountWithId initializes a new Account instance with the provided account ID, user ID, and domain. +// It creates default settings and establishes an initial user, group, and policy. +func newAccountWithId(ctx context.Context, transaction Store, accountID, userID, domain string) error { log.WithContext(ctx).Debugf("creating new account") - network := NewNetwork() - peers := make(map[string]*nbpeer.Peer) - users := make(map[string]*User) - routes := make(map[route.ID]*route.Route) - setupKeys := map[string]*SetupKey{} - nameServersGroups := make(map[string]*nbdns.NameServerGroup) + acc := &Account{ + Id: accountID, + CreatedAt: time.Now().UTC(), + Network: NewNetwork(), + CreatedBy: userID, + Domain: domain, + DNSSettings: DNSSettings{ + DisabledManagementGroups: make([]string, 0), + }, + Settings: &Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: DefaultPeerLoginExpiration, + GroupsPropagationEnabled: true, + RegularUsersViewBlocked: true, + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: DefaultPeerInactivityExpiration, + Extra: &account.ExtraSettings{ + PeerApprovalEnabled: false, + IntegratedValidatorGroups: make([]string, 0), + }, + }, + } + if err := transaction.CreateAccount(ctx, LockingStrengthUpdate, acc); err != nil { + return err + } owner := NewOwnerUser(userID) owner.AccountID = accountID - users[userID] = owner - - dnsSettings := DNSSettings{ - DisabledManagementGroups: make([]string, 0), + if err := transaction.SaveUser(ctx, LockingStrengthUpdate, owner); err != nil { + return err } - log.WithContext(ctx).Debugf("created new account %s", accountID) - acc := &Account{ - Id: accountID, - CreatedAt: time.Now().UTC(), - SetupKeys: setupKeys, - Network: network, - Peers: peers, - Users: users, - CreatedBy: userID, - Domain: domain, - Routes: routes, - NameServerGroups: nameServersGroups, - DNSSettings: dnsSettings, - Settings: &Settings{ - PeerLoginExpirationEnabled: true, - PeerLoginExpiration: DefaultPeerLoginExpiration, - GroupsPropagationEnabled: true, - RegularUsersViewBlocked: true, + allGroup := &nbgroup.Group{ + ID: xid.New().String(), + AccountID: accountID, + Name: "All", + Issued: nbgroup.GroupIssuedAPI, + } + if err := transaction.SaveGroup(ctx, LockingStrengthUpdate, allGroup); err != nil { + return err + } - PeerInactivityExpirationEnabled: false, - PeerInactivityExpiration: DefaultPeerInactivityExpiration, + policyID := xid.New().String() + defaultPolicy := &Policy{ + ID: policyID, + AccountID: accountID, + Name: DefaultPolicyName, + Description: DefaultPolicyDescription, + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + PolicyID: policyID, + Name: DefaultRuleName, + Description: DefaultRuleDescription, + Enabled: true, + Sources: []string{allGroup.ID}, + Destinations: []string{allGroup.ID}, + Bidirectional: true, + Protocol: PolicyRuleProtocolALL, + Action: PolicyTrafficActionAccept, + }, }, } - - if err := addAllGroup(acc); err != nil { - log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) + if err := transaction.CreatePolicy(ctx, LockingStrengthUpdate, defaultPolicy); err != nil { + return err } - return acc + + log.WithContext(ctx).Debugf("created new account %s", accountID) + + return nil } // extractJWTGroups extracts the group names from a JWT token's claims. diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 00ba4fd59..9889552b8 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -22,9 +22,9 @@ import ( ) type MockAccountManager struct { - GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error) - GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error) - CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, + GetOrCreateAccountIDByUserFunc func(ctx context.Context, userId, domain string) (string, error) + GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error) + CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) @@ -176,16 +176,16 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID, return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented") } -// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface -func (am *MockAccountManager) GetOrCreateAccountByUser( +// GetOrCreateAccountIDByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface +func (am *MockAccountManager) GetOrCreateAccountIDByUser( ctx context.Context, userId, domain string, -) (*server.Account, error) { - if am.GetOrCreateAccountByUserFunc != nil { - return am.GetOrCreateAccountByUserFunc(ctx, userId, domain) +) (string, error) { + if am.GetOrCreateAccountIDByUserFunc != nil { + return am.GetOrCreateAccountIDByUserFunc(ctx, userId, domain) } - return nil, status.Errorf( + return "", status.Errorf( codes.Unimplemented, - "method GetOrCreateAccountByUser is not implemented", + "method GetOrCreateAccountIDByUser is not implemented", ) } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index d3c3730d6..1d9086a0f 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -332,24 +332,27 @@ func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, a return nil } -func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { +func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain *bool) error { accountCopy := Account{ - Domain: domain, - DomainCategory: category, - IsDomainPrimaryAccount: isPrimaryDomain, + Domain: domain, + DomainCategory: category, } - fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"} - result := s.db.Model(&Account{}). - Select(fieldsToUpdate). - Where(idQueryCondition, accountID). - Updates(&accountCopy) + fieldsToUpdate := []string{"domain", "domain_category"} + if isPrimaryDomain != nil { + accountCopy.IsDomainPrimaryAccount = *isPrimaryDomain + fieldsToUpdate = append(fieldsToUpdate, "is_domain_primary_account") + } + + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select(fieldsToUpdate). + Where(idQueryCondition, accountID).Updates(&accountCopy) if result.Error != nil { - return result.Error + log.WithContext(ctx).Errorf("failed to update account domain attributes in store: %v", result.Error) + return status.Errorf(status.Internal, "failed to update account domain attributes in store") } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "account %s", accountID) + return status.NewAccountNotFoundError(accountID) } return nil @@ -1728,6 +1731,15 @@ func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength Locking return nil } +func (s *SqlStore) CreateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(&account) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save new account in store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save new account in store") + } + return nil +} + // GetPATByHashedToken returns a PersonalAccessToken by its hashed token. func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error) { var pat PersonalAccessToken diff --git a/management/server/store.go b/management/server/store.go index 900130273..a846eacd7 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -61,9 +61,10 @@ type Store interface { GetTotalAccounts(ctx context.Context, lockStrength LockingStrength) (int64, error) SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error - UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error + UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain *bool) error SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error + CreateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) @@ -83,7 +84,7 @@ type Store interface { GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) - GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) + GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error diff --git a/management/server/user.go b/management/server/user.go index 1639ec50f..ac4db48c5 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -953,8 +953,9 @@ func validateUserUpdate(groupsMap map[string]*nbgroup.Group, initiatorUser, oldU return nil } -// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist -func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*Account, error) { +// GetOrCreateAccountIDByUser returns the account ID for a given user ID. +// If no account exists for the user, it creates a new one using the specified domain. +func (am *DefaultAccountManager) GetOrCreateAccountIDByUser(ctx context.Context, userID, domain string) (string, error) { start := time.Now() unlock := am.Store.AcquireGlobalLock(ctx) defer unlock() @@ -962,34 +963,39 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u lowerDomain := strings.ToLower(domain) - account, err := am.Store.GetAccountByUser(ctx, userID) + accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - account, err = am.newAccount(ctx, userID, lowerDomain) + accountID, err = am.newAccount(ctx, userID, lowerDomain) if err != nil { - return nil, err - } - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, err + return "", err } + return accountID, nil } else { // other error - return nil, err + return "", err } } - userObj := account.Users[userID] - - if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == UserRoleOwner { - account.Domain = lowerDomain - err = am.Store.SaveAccount(ctx, account) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + user, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return nil, status.Errorf(status.Internal, "failed updating account with domain") + return err } - } - return account, nil + accDomain, accCategory, err := transaction.GetAccountDomainAndCategory(ctx, LockingStrengthUpdate, accountID) + if err != nil { + return err + } + + if lowerDomain != "" && accDomain != lowerDomain && user.Role == UserRoleOwner { + return transaction.UpdateAccountDomainAttributes(ctx, LockingStrengthUpdate, accountID, lowerDomain, accCategory, nil) + } + + return nil + }) + + return accountID, err } // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return