Handle new account creation directly within the store

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-11-04 14:43:14 +03:00
parent 4ad00e784c
commit e513e51e9f
5 changed files with 122 additions and 121 deletions

View File

@@ -69,7 +69,7 @@ func cacheEntryExpiration() time.Duration {
} }
type AccountManager interface { type AccountManager interface {
GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error) GetOrCreateAccountIDByUser(ctx context.Context, userId, domain string) (string, error)
GetAccount(ctx context.Context, accountID string) (*Account, error) GetAccount(ctx context.Context, accountID string) (*Account, error)
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error)
@@ -1268,25 +1268,28 @@ func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx co
// newAccount creates a new Account with a generated ID and generated default setup keys. // newAccount creates a new Account with a generated ID and generated default setup keys.
// If ID is already in use (due to collision) we try one more time before returning error // If ID is already in use (due to collision) we try one more time before returning error
func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) { func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (string, error) {
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
accountId := xid.New().String() accountID := xid.New().String()
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountId) exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error while checking account existence: %v", err) log.WithContext(ctx).Errorf("error while checking account existence: %v", err)
return nil, err return "", err
} }
if !exists { if !exists {
newAccount := newAccountWithId(ctx, accountId, userID, domain) if err = newAccountWithId(ctx, am.Store, accountID, userID, domain); err != nil {
am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil) return "", err
return newAccount, nil }
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountCreated, nil)
return accountID, nil
} }
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...") log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
} }
return nil, status.Errorf(status.Internal, "error while creating new account") return "", status.Errorf(status.Internal, "error while creating new account")
} }
func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error { func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error {
@@ -1400,15 +1403,15 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID) accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) accountID, err = am.GetOrCreateAccountIDByUser(ctx, userID, domain)
if err != nil { if err != nil {
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
} }
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil { if err = am.addAccountIDToIDPAppMeta(ctx, userID, accountID); err != nil {
return "", err return "", err
} }
return account.Id, nil return accountID, nil
} }
return "", err return "", err
} }
@@ -1705,7 +1708,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
newCategory = claims.DomainCategory newCategory = claims.DomainCategory
} }
return am.Store.UpdateAccountDomainAttributes(ctx, LockingStrengthUpdate, accountID, newDomain, newCategory, primaryDomain) return am.Store.UpdateAccountDomainAttributes(ctx, LockingStrengthUpdate, accountID, newDomain, newCategory, &primaryDomain)
} }
// handleExistingUserAccount handles existing User accounts and update its domain attributes. // handleExistingUserAccount handles existing User accounts and update its domain attributes.
@@ -1743,29 +1746,26 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
} }
lowerDomain := strings.ToLower(claims.Domain) lowerDomain := strings.ToLower(claims.Domain)
isPrimaryDomain := true
newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain) newAccountID, err := am.newAccount(ctx, claims.UserId, lowerDomain)
if err != nil { if err != nil {
return "", err return "", err
} }
newAccount.Domain = lowerDomain err = am.Store.UpdateAccountDomainAttributes(ctx, LockingStrengthUpdate, newAccountID, lowerDomain, claims.DomainCategory, &isPrimaryDomain)
newAccount.DomainCategory = claims.DomainCategory
newAccount.IsDomainPrimaryAccount = true
err = am.Store.SaveAccount(ctx, newAccount)
if err != nil { if err != nil {
return "", err return "", err
} }
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id) err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccountID)
if err != nil { if err != nil {
return "", err return "", err
} }
am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil) am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccountID, activity.UserJoined, nil)
return newAccount.Id, nil return newAccountID, nil
} }
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
@@ -2395,23 +2395,56 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
} }
// addAllGroup to account object if it doesn't exist // newAccountWithId initializes a new Account instance with the provided account ID, user ID, and domain.
func addAllGroup(account *Account) error { // It creates default settings and establishes an initial user, group, and policy.
if len(account.Groups) == 0 { func newAccountWithId(ctx context.Context, store Store, accountID, userID, domain string) error {
log.WithContext(ctx).Debugf("creating new account")
return store.ExecuteInTransaction(ctx, func(transaction Store) error {
acc := &Account{
Id: accountID,
CreatedAt: time.Now().UTC(),
Network: NewNetwork(),
CreatedBy: userID,
Domain: domain,
DNSSettings: DNSSettings{
DisabledManagementGroups: make([]string, 0),
},
Settings: &Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: DefaultPeerLoginExpiration,
GroupsPropagationEnabled: true,
RegularUsersViewBlocked: true,
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
},
}
if err := transaction.CreateAccount(ctx, LockingStrengthUpdate, acc); err != nil {
return fmt.Errorf("failed to create account: %w", err)
}
owner := NewOwnerUser(userID)
owner.AccountID = accountID
if err := transaction.SaveUser(ctx, LockingStrengthUpdate, owner); err != nil {
return fmt.Errorf("failed to save account owner: %w", err)
}
allGroup := &nbgroup.Group{ allGroup := &nbgroup.Group{
ID: xid.New().String(), ID: xid.New().String(),
AccountID: accountID,
Name: "All", Name: "All",
Issued: nbgroup.GroupIssuedAPI, Issued: nbgroup.GroupIssuedAPI,
} }
for _, peer := range account.Peers { if err := transaction.SaveGroup(ctx, LockingStrengthUpdate, allGroup); err != nil {
allGroup.Peers = append(allGroup.Peers, peer.ID) return fmt.Errorf("failed to save group All: %w", err)
} }
account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup}
id := xid.New().String() id := xid.New().String()
defaultPolicy := &Policy{ defaultPolicy := &Policy{
ID: id, ID: id,
AccountID: accountID,
Name: DefaultPolicyName, Name: DefaultPolicyName,
Description: DefaultPolicyDescription, Description: DefaultPolicyDescription,
Enabled: true, Enabled: true,
@@ -2429,59 +2462,14 @@ func addAllGroup(account *Account) error {
}, },
}, },
} }
if err := transaction.SavePolicy(ctx, LockingStrengthUpdate, defaultPolicy); err != nil {
account.Policies = []*Policy{defaultPolicy} return fmt.Errorf("failed to save default policy: %w", err)
} }
return nil
}
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Account {
log.WithContext(ctx).Debugf("creating new account")
network := NewNetwork()
peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*User)
routes := make(map[route.ID]*route.Route)
setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
owner := NewOwnerUser(userID)
owner.AccountID = accountID
users[userID] = owner
dnsSettings := DNSSettings{
DisabledManagementGroups: make([]string, 0),
}
log.WithContext(ctx).Debugf("created new account %s", accountID) log.WithContext(ctx).Debugf("created new account %s", accountID)
acc := &Account{ return nil
Id: accountID, })
CreatedAt: time.Now().UTC(),
SetupKeys: setupKeys,
Network: network,
Peers: peers,
Users: users,
CreatedBy: userID,
Domain: domain,
Routes: routes,
NameServerGroups: nameServersGroups,
DNSSettings: dnsSettings,
Settings: &Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: DefaultPeerLoginExpiration,
GroupsPropagationEnabled: true,
RegularUsersViewBlocked: true,
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
},
}
if err := addAllGroup(acc); err != nil {
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
}
return acc
} }
// extractJWTGroups extracts the group names from a JWT token's claims. // extractJWTGroups extracts the group names from a JWT token's claims.

View File

@@ -22,7 +22,7 @@ import (
) )
type MockAccountManager struct { type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error) GetOrCreateAccountIDByUserFunc func(ctx context.Context, userId, domain string) (string, error)
GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error) GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
@@ -177,16 +177,16 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID,
return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented") return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
} }
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface // GetOrCreateAccountIDByUser mock implementation of GetOrCreateAccountIDByUser from server.AccountManager interface
func (am *MockAccountManager) GetOrCreateAccountByUser( func (am *MockAccountManager) GetOrCreateAccountIDByUser(
ctx context.Context, userId, domain string, ctx context.Context, userId, domain string,
) (*server.Account, error) { ) (string, error) {
if am.GetOrCreateAccountByUserFunc != nil { if am.GetOrCreateAccountIDByUserFunc != nil {
return am.GetOrCreateAccountByUserFunc(ctx, userId, domain) return am.GetOrCreateAccountIDByUserFunc(ctx, userId, domain)
} }
return nil, status.Errorf( return "", status.Errorf(
codes.Unimplemented, codes.Unimplemented,
"method GetOrCreateAccountByUser is not implemented", "method GetOrCreateAccountIDByUser is not implemented",
) )
} }

View File

@@ -143,6 +143,15 @@ func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (u
return unlock return unlock
} }
func (s *SqlStore) CreateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Create(&account)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save new account in store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save new account in store")
}
return nil
}
func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error { func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error {
start := time.Now() start := time.Now()
defer func() { defer func() {
@@ -324,14 +333,18 @@ func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, a
return nil return nil
} }
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain bool) error { func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain *bool) error {
accountCopy := Account{ accountCopy := Account{
Domain: domain, Domain: domain,
DomainCategory: category, DomainCategory: category,
IsDomainPrimaryAccount: isPrimaryDomain,
} }
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"} fieldsToUpdate := []string{"domain", "domain_category"}
if isPrimaryDomain != nil {
accountCopy.IsDomainPrimaryAccount = *isPrimaryDomain
fieldsToUpdate = append(fieldsToUpdate, "is_domain_primary_account")
}
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Select(fieldsToUpdate). Select(fieldsToUpdate).
Where(idQueryCondition, accountID). Where(idQueryCondition, accountID).

View File

@@ -61,9 +61,10 @@ type Store interface {
GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error)
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error
CreateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error
SaveAccount(ctx context.Context, account *Account) error SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error
UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain bool) error UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain *bool) error
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)

View File

@@ -927,43 +927,42 @@ func validateUserUpdate(groupsMap map[string]*nbgroup.Group, initiatorUser, oldU
return nil return nil
} }
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist // GetOrCreateAccountIDByUser returns the account ID for a given user ID.
func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*Account, error) { // If no account exists for the user, it creates a new one using the specified domain.
start := time.Now() func (am *DefaultAccountManager) GetOrCreateAccountIDByUser(ctx context.Context, userID, domain string) (string, error) {
unlock := am.Store.AcquireGlobalLock(ctx)
defer unlock()
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), userID)
lowerDomain := strings.ToLower(domain) lowerDomain := strings.ToLower(domain)
accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID)
account, err := am.Store.GetAccountByUser(ctx, userID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err = am.newAccount(ctx, userID, lowerDomain) accountID, err = am.newAccount(ctx, userID, lowerDomain)
if err != nil { if err != nil {
return nil, err return "", err
}
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err
} }
return accountID, nil
} else { } else {
// other error // other error
return nil, err return "", err
} }
} }
userObj := account.Users[userID] user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == UserRoleOwner {
account.Domain = lowerDomain
err = am.Store.SaveAccount(ctx, account)
if err != nil { if err != nil {
return nil, status.Errorf(status.Internal, "failed updating account with domain") return "", err
}
accDomain, accCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID)
if err != nil {
return "", err
}
if lowerDomain != "" && accDomain != lowerDomain && user.Role == UserRoleOwner {
err = am.Store.UpdateAccountDomainAttributes(ctx, LockingStrengthUpdate, accountID, lowerDomain, accCategory, nil)
if err != nil {
return "", err
} }
} }
return account, nil return "", nil
} }
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return