diff --git a/management/server/account.go b/management/server/account.go index 3303e9dee..e6c8ba661 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1611,3 +1611,113 @@ func separateGroups(autoGroups []string, allGroups []*types.Group) ([]string, ma func (am *DefaultAccountManager) GetStore() store.Store { return am.Store } + +// Creates account by private domain. +// Expects domain value to be a valid and a private dns domain. +func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) { + cancel := am.Store.AcquireGlobalLock(ctx) + defer cancel() + + domain = strings.ToLower(domain) + + count, err := am.Store.CountAccountsByPrivateDomain(ctx, domain) + if err != nil { + return nil, err + } + + if count > 0 { + return nil, status.Errorf(status.InvalidArgument, "account with private domain already exists") + } + + // retry twice for new ID clashes + for range 2 { + accountId := xid.New().String() + + exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId) + if err != nil || exists { + continue + } + + network := types.NewNetwork() + peers := make(map[string]*nbpeer.Peer) + users := make(map[string]*types.User) + routes := make(map[route.ID]*route.Route) + setupKeys := map[string]*types.SetupKey{} + nameServersGroups := make(map[string]*nbdns.NameServerGroup) + + dnsSettings := types.DNSSettings{ + DisabledManagementGroups: make([]string, 0), + } + + newAccount := &types.Account{ + Id: accountId, + CreatedAt: time.Now().UTC(), + SetupKeys: setupKeys, + Network: network, + Peers: peers, + Users: users, + // @todo check if using the MSP owner id here is ok + CreatedBy: initiatorId, + Domain: domain, + DomainCategory: types.PrivateCategory, + IsDomainPrimaryAccount: false, + Routes: routes, + NameServerGroups: nameServersGroups, + DNSSettings: dnsSettings, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + GroupsPropagationEnabled: true, + RegularUsersViewBlocked: true, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, + RoutingPeerDNSResolutionEnabled: true, + }, + } + + if err := newAccount.AddAllGroup(); err != nil { + return nil, status.Errorf(status.Internal, "failed to add all group to new account by private domain") + } + + if err := am.Store.SaveAccount(ctx, newAccount); err != nil { + log.WithContext(ctx).Errorf("failed to save new account %s by private domain: %v", newAccount.Id, err) + return nil, err + } + + am.StoreEvent(ctx, initiatorId, newAccount.Id, accountId, activity.AccountCreated, nil) + return newAccount, nil + } + + return nil, status.Errorf(status.Internal, "failed to create new account by private domain") +} + +func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) { + account, err := am.Store.GetAccount(ctx, accountId) + if err != nil { + return nil, err + } + + if account.IsDomainPrimaryAccount { + return account, nil + } + + // additional check to ensure there is only one account for this domain at the time of update + count, err := am.Store.CountAccountsByPrivateDomain(ctx, account.Domain) + if err != nil { + return nil, err + } + + if count > 1 { + return nil, status.Errorf(status.Internal, "more than one account exists with the same private domain") + } + + account.IsDomainPrimaryAccount = true + + if err := am.Store.SaveAccount(ctx, account); err != nil { + log.WithContext(ctx).Errorf("failed to update primary account %s by private domain: %v", account.Id, err) + return nil, status.Errorf(status.Internal, "failed to update primary account %s by private domain", account.Id) + } + + return account, nil +} diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 37c50267b..f482f2f51 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -111,4 +111,6 @@ type Manager interface { BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error GetStore() store.Store + CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) + UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) } diff --git a/management/server/account_test.go b/management/server/account_test.go index 2690f0c27..b53533cf9 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3153,3 +3153,51 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { }) } } + +func Test_CreateAccountByPrivateDomain(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + ctx := context.Background() + initiatorId := "test-user" + domain := "example.com" + + account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain) + assert.NoError(t, err) + + assert.False(t, account.IsDomainPrimaryAccount) + assert.Equal(t, domain, account.Domain) + assert.Equal(t, types.PrivateCategory, account.DomainCategory) + assert.Equal(t, initiatorId, account.CreatedBy) + assert.Equal(t, 1, len(account.Groups)) + assert.Equal(t, 0, len(account.Users)) + assert.Equal(t, 0, len(account.SetupKeys)) + + // retry should fail + _, err = manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain) + assert.Error(t, err) +} + +func Test_UpdateToPrimaryAccount(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + ctx := context.Background() + initiatorId := "test-user" + domain := "example.com" + + account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain) + assert.NoError(t, err) + assert.False(t, account.IsDomainPrimaryAccount) + + // retry should fail + account, err = manager.UpdateToPrimaryAccount(ctx, account.Id) + assert.NoError(t, err) + assert.True(t, account.IsDomainPrimaryAccount) +} diff --git a/management/server/event.go b/management/server/event.go index 7bcf8ae25..12d9eb9da 100644 --- a/management/server/event.go +++ b/management/server/event.go @@ -188,6 +188,7 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) if err != nil { + // @todo consider logging continue } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index cb8d598f8..3fb2180ae 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -112,6 +112,8 @@ type MockAccountManager struct { DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) GetStoreFunc func() store.Store + CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error) + UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error) } func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { @@ -847,3 +849,17 @@ func (am *MockAccountManager) GetStore() store.Store { } return nil } + +func (am *MockAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) { + if am.CreateAccountByPrivateDomainFunc != nil { + return am.CreateAccountByPrivateDomainFunc(ctx, initiatorId, domain) + } + return nil, status.Errorf(codes.Unimplemented, "method CreateAccountByPrivateDomain is not implemented") +} + +func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) { + if am.UpdateToPrimaryAccountFunc != nil { + return am.UpdateToPrimaryAccountFunc(ctx, accountId) + } + return nil, status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented") +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index cf6665665..a7e479e4e 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2190,3 +2190,17 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength return &peer, nil } + +func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) { + var count int64 + result := s.db.Model(&types.Account{}). + Where("domain = ? AND domain_category = ?", + strings.ToLower(domain), types.PrivateCategory, + ).Count(&count) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to count accounts by private domain %s: %s", domain, result.Error) + return 0, status.Errorf(status.Internal, "failed to count accounts by private domain") + } + + return count, nil +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 9ff0c5636..b2ae3c37d 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -69,6 +69,7 @@ type Store interface { DeleteAccount(ctx context.Context, account *types.Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error + CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)