mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-02 23:26:41 +00:00
Refactor new account handling
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
@@ -66,7 +66,7 @@ func cacheEntryExpiration() time.Duration {
|
||||
}
|
||||
|
||||
type AccountManager interface {
|
||||
GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error)
|
||||
GetOrCreateAccountIDByUser(ctx context.Context, userId, domain string) (string, error)
|
||||
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
|
||||
autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error)
|
||||
@@ -1267,26 +1267,39 @@ func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx co
|
||||
|
||||
// newAccount creates a new Account with a generated ID and generated default setup keys.
|
||||
// If ID is already in use (due to collision) we try one more time before returning error
|
||||
func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) {
|
||||
for i := 0; i < 2; i++ {
|
||||
accountId := xid.New().String()
|
||||
func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (string, error) {
|
||||
var accountID string
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
for i := 0; i < 2; i++ {
|
||||
accountID = xid.New().String()
|
||||
|
||||
exists, err := transaction.AccountExists(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to check account existence: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
if err = newAccountWithId(ctx, transaction, accountID, userID, domain); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to create new account: %v", err)
|
||||
return err
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountCreated, nil)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := am.Store.GetAccount(ctx, accountId)
|
||||
statusErr, _ := status.FromError(err)
|
||||
switch {
|
||||
case err == nil:
|
||||
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
|
||||
continue
|
||||
case statusErr.Type() == status.NotFound:
|
||||
newAccount := newAccountWithId(ctx, accountId, userID, domain)
|
||||
am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil)
|
||||
return newAccount, nil
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", status.Errorf(status.Internal, "failed to create new account")
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.Internal, "error while creating new account")
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error {
|
||||
@@ -1400,15 +1413,15 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
|
||||
accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
||||
accountID, err = am.GetOrCreateAccountIDByUser(ctx, userID, domain)
|
||||
if err != nil {
|
||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||
}
|
||||
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, accountID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account.Id, nil
|
||||
return accountID, nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
@@ -1709,7 +1722,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
|
||||
newCategoty = claims.DomainCategory
|
||||
}
|
||||
|
||||
return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain)
|
||||
return am.Store.UpdateAccountDomainAttributes(ctx, LockingStrengthUpdate, accountID, newDomain, newCategoty, &primaryDomain)
|
||||
}
|
||||
|
||||
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
|
||||
@@ -1747,29 +1760,26 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
|
||||
}
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
isPrimaryDomain := true
|
||||
|
||||
newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain)
|
||||
newAccountID, err := am.newAccount(ctx, claims.UserId, lowerDomain)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
newAccount.Domain = lowerDomain
|
||||
newAccount.DomainCategory = claims.DomainCategory
|
||||
newAccount.IsDomainPrimaryAccount = true
|
||||
|
||||
err = am.Store.SaveAccount(ctx, newAccount)
|
||||
err = am.Store.UpdateAccountDomainAttributes(ctx, LockingStrengthUpdate, newAccountID, lowerDomain, claims.DomainCategory, &isPrimaryDomain)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id)
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccountID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil)
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccountID, activity.UserJoined, nil)
|
||||
|
||||
return newAccount.Id, nil
|
||||
return newAccountID, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
@@ -2419,93 +2429,82 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
|
||||
return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// addAllGroup to account object if it doesn't exist
|
||||
func addAllGroup(account *Account) error {
|
||||
if len(account.Groups) == 0 {
|
||||
allGroup := &nbgroup.Group{
|
||||
ID: xid.New().String(),
|
||||
Name: "All",
|
||||
Issued: nbgroup.GroupIssuedAPI,
|
||||
}
|
||||
for _, peer := range account.Peers {
|
||||
allGroup.Peers = append(allGroup.Peers, peer.ID)
|
||||
}
|
||||
account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup}
|
||||
|
||||
id := xid.New().String()
|
||||
|
||||
defaultPolicy := &Policy{
|
||||
ID: id,
|
||||
Name: DefaultRuleName,
|
||||
Description: DefaultRuleDescription,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: id,
|
||||
Name: DefaultRuleName,
|
||||
Description: DefaultRuleDescription,
|
||||
Enabled: true,
|
||||
Sources: []string{allGroup.ID},
|
||||
Destinations: []string{allGroup.ID},
|
||||
Bidirectional: true,
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
account.Policies = []*Policy{defaultPolicy}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Account {
|
||||
// newAccountWithId initializes a new Account instance with the provided account ID, user ID, and domain.
|
||||
// It creates default settings and establishes an initial user, group, and policy.
|
||||
func newAccountWithId(ctx context.Context, transaction Store, accountID, userID, domain string) error {
|
||||
log.WithContext(ctx).Debugf("creating new account")
|
||||
|
||||
network := NewNetwork()
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
users := make(map[string]*User)
|
||||
routes := make(map[route.ID]*route.Route)
|
||||
setupKeys := map[string]*SetupKey{}
|
||||
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
||||
acc := &Account{
|
||||
Id: accountID,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Network: NewNetwork(),
|
||||
CreatedBy: userID,
|
||||
Domain: domain,
|
||||
DNSSettings: DNSSettings{
|
||||
DisabledManagementGroups: make([]string, 0),
|
||||
},
|
||||
Settings: &Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: DefaultPeerLoginExpiration,
|
||||
GroupsPropagationEnabled: true,
|
||||
RegularUsersViewBlocked: true,
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
|
||||
Extra: &account.ExtraSettings{
|
||||
PeerApprovalEnabled: false,
|
||||
IntegratedValidatorGroups: make([]string, 0),
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := transaction.CreateAccount(ctx, LockingStrengthUpdate, acc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
owner := NewOwnerUser(userID)
|
||||
owner.AccountID = accountID
|
||||
users[userID] = owner
|
||||
|
||||
dnsSettings := DNSSettings{
|
||||
DisabledManagementGroups: make([]string, 0),
|
||||
if err := transaction.SaveUser(ctx, LockingStrengthUpdate, owner); err != nil {
|
||||
return err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("created new account %s", accountID)
|
||||
|
||||
acc := &Account{
|
||||
Id: accountID,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SetupKeys: setupKeys,
|
||||
Network: network,
|
||||
Peers: peers,
|
||||
Users: users,
|
||||
CreatedBy: userID,
|
||||
Domain: domain,
|
||||
Routes: routes,
|
||||
NameServerGroups: nameServersGroups,
|
||||
DNSSettings: dnsSettings,
|
||||
Settings: &Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: DefaultPeerLoginExpiration,
|
||||
GroupsPropagationEnabled: true,
|
||||
RegularUsersViewBlocked: true,
|
||||
allGroup := &nbgroup.Group{
|
||||
ID: xid.New().String(),
|
||||
AccountID: accountID,
|
||||
Name: "All",
|
||||
Issued: nbgroup.GroupIssuedAPI,
|
||||
}
|
||||
if err := transaction.SaveGroup(ctx, LockingStrengthUpdate, allGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
|
||||
policyID := xid.New().String()
|
||||
defaultPolicy := &Policy{
|
||||
ID: policyID,
|
||||
AccountID: accountID,
|
||||
Name: DefaultPolicyName,
|
||||
Description: DefaultPolicyDescription,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: xid.New().String(),
|
||||
PolicyID: policyID,
|
||||
Name: DefaultRuleName,
|
||||
Description: DefaultRuleDescription,
|
||||
Enabled: true,
|
||||
Sources: []string{allGroup.ID},
|
||||
Destinations: []string{allGroup.ID},
|
||||
Bidirectional: true,
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := addAllGroup(acc); err != nil {
|
||||
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
|
||||
if err := transaction.CreatePolicy(ctx, LockingStrengthUpdate, defaultPolicy); err != nil {
|
||||
return err
|
||||
}
|
||||
return acc
|
||||
|
||||
log.WithContext(ctx).Debugf("created new account %s", accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractJWTGroups extracts the group names from a JWT token's claims.
|
||||
|
||||
Reference in New Issue
Block a user