[management] create account by private domain (#3485)

This commit is contained in:
Pedro Maia Costa
2025-03-14 14:29:54 +00:00
committed by Pedro Costa
parent abaffbcc2d
commit 1df01a1ebf
7 changed files with 192 additions and 0 deletions

View File

@@ -1611,3 +1611,113 @@ func separateGroups(autoGroups []string, allGroups []*types.Group) ([]string, ma
func (am *DefaultAccountManager) GetStore() store.Store {
return am.Store
}
// Creates account by private domain.
// Expects domain value to be a valid and a private dns domain.
func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) {
cancel := am.Store.AcquireGlobalLock(ctx)
defer cancel()
domain = strings.ToLower(domain)
count, err := am.Store.CountAccountsByPrivateDomain(ctx, domain)
if err != nil {
return nil, err
}
if count > 0 {
return nil, status.Errorf(status.InvalidArgument, "account with private domain already exists")
}
// retry twice for new ID clashes
for range 2 {
accountId := xid.New().String()
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId)
if err != nil || exists {
continue
}
network := types.NewNetwork()
peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*types.User)
routes := make(map[route.ID]*route.Route)
setupKeys := map[string]*types.SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
dnsSettings := types.DNSSettings{
DisabledManagementGroups: make([]string, 0),
}
newAccount := &types.Account{
Id: accountId,
CreatedAt: time.Now().UTC(),
SetupKeys: setupKeys,
Network: network,
Peers: peers,
Users: users,
// @todo check if using the MSP owner id here is ok
CreatedBy: initiatorId,
Domain: domain,
DomainCategory: types.PrivateCategory,
IsDomainPrimaryAccount: false,
Routes: routes,
NameServerGroups: nameServersGroups,
DNSSettings: dnsSettings,
Settings: &types.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
GroupsPropagationEnabled: true,
RegularUsersViewBlocked: true,
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
RoutingPeerDNSResolutionEnabled: true,
},
}
if err := newAccount.AddAllGroup(); err != nil {
return nil, status.Errorf(status.Internal, "failed to add all group to new account by private domain")
}
if err := am.Store.SaveAccount(ctx, newAccount); err != nil {
log.WithContext(ctx).Errorf("failed to save new account %s by private domain: %v", newAccount.Id, err)
return nil, err
}
am.StoreEvent(ctx, initiatorId, newAccount.Id, accountId, activity.AccountCreated, nil)
return newAccount, nil
}
return nil, status.Errorf(status.Internal, "failed to create new account by private domain")
}
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil {
return nil, err
}
if account.IsDomainPrimaryAccount {
return account, nil
}
// additional check to ensure there is only one account for this domain at the time of update
count, err := am.Store.CountAccountsByPrivateDomain(ctx, account.Domain)
if err != nil {
return nil, err
}
if count > 1 {
return nil, status.Errorf(status.Internal, "more than one account exists with the same private domain")
}
account.IsDomainPrimaryAccount = true
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to update primary account %s by private domain: %v", account.Id, err)
return nil, status.Errorf(status.Internal, "failed to update primary account %s by private domain", account.Id)
}
return account, nil
}

View File

@@ -111,4 +111,6 @@ type Manager interface {
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
GetStore() store.Store
CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error)
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
}

View File

@@ -3153,3 +3153,51 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
})
}
}
func Test_CreateAccountByPrivateDomain(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
ctx := context.Background()
initiatorId := "test-user"
domain := "example.com"
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.NoError(t, err)
assert.False(t, account.IsDomainPrimaryAccount)
assert.Equal(t, domain, account.Domain)
assert.Equal(t, types.PrivateCategory, account.DomainCategory)
assert.Equal(t, initiatorId, account.CreatedBy)
assert.Equal(t, 1, len(account.Groups))
assert.Equal(t, 0, len(account.Users))
assert.Equal(t, 0, len(account.SetupKeys))
// retry should fail
_, err = manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.Error(t, err)
}
func Test_UpdateToPrimaryAccount(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
ctx := context.Background()
initiatorId := "test-user"
domain := "example.com"
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.NoError(t, err)
assert.False(t, account.IsDomainPrimaryAccount)
// retry should fail
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
assert.NoError(t, err)
assert.True(t, account.IsDomainPrimaryAccount)
}

View File

@@ -188,6 +188,7 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
if err != nil {
// @todo consider logging
continue
}

View File

@@ -112,6 +112,8 @@ type MockAccountManager struct {
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
GetStoreFunc func() store.Store
CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error)
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
}
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
@@ -847,3 +849,17 @@ func (am *MockAccountManager) GetStore() store.Store {
}
return nil
}
func (am *MockAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) {
if am.CreateAccountByPrivateDomainFunc != nil {
return am.CreateAccountByPrivateDomainFunc(ctx, initiatorId, domain)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateAccountByPrivateDomain is not implemented")
}
func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
if am.UpdateToPrimaryAccountFunc != nil {
return am.UpdateToPrimaryAccountFunc(ctx, accountId)
}
return nil, status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented")
}

View File

@@ -2190,3 +2190,17 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength
return &peer, nil
}
func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) {
var count int64
result := s.db.Model(&types.Account{}).
Where("domain = ? AND domain_category = ?",
strings.ToLower(domain), types.PrivateCategory,
).Count(&count)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to count accounts by private domain %s: %s", domain, result.Error)
return 0, status.Errorf(status.Internal, "failed to count accounts by private domain")
}
return count, nil
}

View File

@@ -69,6 +69,7 @@ type Store interface {
DeleteAccount(ctx context.Context, account *types.Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error)
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)