Merge branch 'main' into handle-existing-domain-user

# Conflicts:
#	management/server/account.go
#	management/server/account_test.go
This commit is contained in:
bcmmbaga
2025-08-12 13:31:40 +03:00
631 changed files with 36478 additions and 12888 deletions

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"net/netip"
"os"
"reflect"
"strconv"
@@ -14,7 +15,6 @@ import (
"time"
"github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/management/server/idp"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -25,6 +25,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
@@ -373,7 +374,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
for _, testCase := range tt {
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io", false)
account.UpdateSettings(&testCase.accountSettings)
account.Network = network
account.Peers = testCase.peers
@@ -398,7 +399,7 @@ func TestNewAccount(t *testing.T) {
domain := "netbird.io"
userId := "account_creator"
accountID := "account_id"
account := newAccountWithId(context.Background(), accountID, userId, domain)
account := newAccountWithId(context.Background(), accountID, userId, domain, false)
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
}
@@ -640,7 +641,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
userId := "user-id"
domain := "test.domain"
_ = newAccountWithId(context.Background(), "", userId, domain)
_ = newAccountWithId(context.Background(), "", userId, domain, false)
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
@@ -782,7 +783,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
return
}
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID)
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthNone, accountID)
assert.NoError(t, err)
assert.True(t, exists, "expected to get existing account after creation using userid")
@@ -793,7 +794,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
}
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) {
account := newAccountWithId(context.Background(), accountID, userID, domain)
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
@@ -899,11 +900,11 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
}
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, "service-user-1")
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, "service-user-1")
require.NoError(t, err)
assert.Len(t, pats, 0)
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, userId)
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, userId)
require.NoError(t, err)
assert.Len(t, pats, 0)
}
@@ -1159,7 +1160,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
Name: "GroupA",
Peers: []string{},
}
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1194,7 +1195,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
}()
group.Peers = []string{peer1.ID, peer2.ID, peer3.ID}
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
if err := manager.UpdateGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1208,6 +1209,14 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
// Ensure that we do not receive an update message before the policy is deleted
time.Sleep(time.Second)
select {
case <-updMsg:
t.Logf("received addPeer update message before policy deletion")
default:
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
@@ -1232,11 +1241,12 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
group := types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
AccountID: account.Id,
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
}
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1284,7 +1294,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
Name: "GroupA",
Peers: []string{peer1.ID, peer3.ID},
}
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
@@ -1335,11 +1345,11 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}, true)
})
require.NoError(t, err, "failed to save group")
@@ -1664,9 +1674,10 @@ func TestAccount_Copy(t *testing.T) {
},
Groups: map[string]*types.Group{
"group1": {
ID: "group1",
Peers: []string{"peer1"},
Resources: []types.Resource{},
ID: "group1",
Peers: []string{"peer1"},
Resources: []types.Resource{},
GroupPeers: []types.GroupPeer{},
},
},
Policies: []*types.Policy{
@@ -1775,7 +1786,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
require.NoError(t, err, "unable to get account settings")
assert.NotNil(t, settings)
@@ -1805,9 +1816,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
require.NoError(t, err, "unable to mark peer connected")
account, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
@@ -1825,11 +1837,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
// disable expiration first
update := peer.Copy()
update.LoginExpirationEnabled = false
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
_, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
require.NoError(t, err, "unable to update peer")
// enabling expiration should trigger the routine
update.LoginExpirationEnabled = true
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
_, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
require.NoError(t, err, "unable to update peer")
failed := waitTimeout(wg, time.Second)
@@ -1856,15 +1868,13 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
wg := &sync.WaitGroup{}
wg.Add(2)
wg.Add(1)
manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(ctx context.Context, IDs []string) {
wg.Done()
},
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done()
},
@@ -1919,9 +1929,10 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
},
}
// enabling PeerLoginExpirationEnabled should trigger the expiration job
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
@@ -1935,6 +1946,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
failed = waitTimeout(wg, time.Second)
@@ -1950,15 +1962,16 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
require.NoError(t, err, "unable to get account settings")
assert.False(t, settings.PeerLoginExpirationEnabled)
@@ -1967,12 +1980,14 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Second,
PeerLoginExpirationEnabled: false,
Extra: &types.ExtraSettings{},
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour * 24 * 181,
PeerLoginExpirationEnabled: false,
Extra: &types.ExtraSettings{},
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
}
@@ -2604,6 +2619,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
}
func TestAccount_SetJWTGroups(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "postgres")
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -2611,11 +2627,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
account := &types.Account{
Id: "accountID",
Peers: map[string]*nbpeer.Peer{
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
"peer3": {ID: "peer3", Key: "key3", UserID: "user1"},
"peer4": {ID: "peer4", Key: "key4", UserID: "user2"},
"peer5": {ID: "peer5", Key: "key5", UserID: "user2"},
"peer1": {ID: "peer1", Key: "key1", UserID: "user1", IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"},
"peer2": {ID: "peer2", Key: "key2", UserID: "user1", IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"},
"peer3": {ID: "peer3", Key: "key3", UserID: "user1", IP: net.IP{3, 3, 3, 3}, DNSLabel: "peer3.domain.test"},
"peer4": {ID: "peer4", Key: "key4", UserID: "user2", IP: net.IP{4, 4, 4, 4}, DNSLabel: "peer4.domain.test"},
"peer5": {ID: "peer5", Key: "key5", UserID: "user2", IP: net.IP{5, 5, 5, 5}, DNSLabel: "peer5.domain.test"},
},
Groups: map[string]*types.Group{
"group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}},
@@ -2639,7 +2655,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
})
@@ -2653,7 +2669,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Empty(t, user.AutoGroups, "auto groups must be empty")
})
@@ -2667,18 +2683,18 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
})
t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
account.Users["user1"].AutoGroups = []string{"group1"}
assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"]))
assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"]))
claims := nbcontext.UserAuth{
UserId: "user1",
@@ -2688,11 +2704,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
})
@@ -2706,7 +2722,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
})
@@ -2720,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
})
@@ -2734,11 +2750,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID")
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, "accountID")
assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added")
})
@@ -2752,7 +2768,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
@@ -2767,7 +2783,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
})
@@ -2875,7 +2891,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store)
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, err
}
@@ -3135,11 +3151,11 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 7, 20, 10, 80},
{"Small", 50, 5, 7, 20, 5, 80},
{"Medium", 500, 100, 5, 40, 30, 140},
{"Large", 5000, 200, 80, 120, 140, 390},
{"Small single", 50, 10, 7, 20, 10, 80},
{"Medium single", 500, 10, 5, 40, 20, 85},
{"Small single", 50, 10, 7, 20, 6, 80},
{"Medium single", 500, 10, 5, 40, 15, 85},
{"Large 5", 5000, 15, 80, 120, 80, 200},
}
@@ -3198,7 +3214,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
}
}
func Test_CreateAccountByPrivateDomain(t *testing.T) {
func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
@@ -3209,9 +3225,10 @@ func Test_CreateAccountByPrivateDomain(t *testing.T) {
initiatorId := "test-user"
domain := "example.com"
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
account, created, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.NoError(t, err)
assert.True(t, created)
assert.False(t, account.IsDomainPrimaryAccount)
assert.Equal(t, domain, account.Domain)
assert.Equal(t, types.PrivateCategory, account.DomainCategory)
@@ -3220,9 +3237,25 @@ func Test_CreateAccountByPrivateDomain(t *testing.T) {
assert.Equal(t, 0, len(account.Users))
assert.Equal(t, 0, len(account.SetupKeys))
// retry should fail
_, err = manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.Error(t, err)
// should return a new account because the previous one is not primary
account2, created2, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.NoError(t, err)
assert.True(t, created2)
assert.False(t, account2.IsDomainPrimaryAccount)
assert.Equal(t, domain, account2.Domain)
assert.Equal(t, types.PrivateCategory, account2.DomainCategory)
assert.Equal(t, initiatorId, account2.CreatedBy)
assert.Equal(t, 1, len(account2.Groups))
assert.Equal(t, 0, len(account2.Users))
assert.Equal(t, 0, len(account2.SetupKeys))
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
assert.NoError(t, err)
assert.True(t, account.IsDomainPrimaryAccount)
_, err = manager.UpdateToPrimaryAccount(ctx, account2.Id)
assert.Error(t, err, "should not be able to update a second account to primary")
}
func Test_UpdateToPrimaryAccount(t *testing.T) {
@@ -3236,14 +3269,21 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
initiatorId := "test-user"
domain := "example.com"
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
account, created, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.NoError(t, err)
assert.True(t, created)
assert.False(t, account.IsDomainPrimaryAccount)
assert.Equal(t, domain, account.Domain)
// retry should fail
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
assert.NoError(t, err)
assert.True(t, account.IsDomainPrimaryAccount)
account2, created2, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.NoError(t, err)
assert.False(t, created2)
assert.True(t, account.IsDomainPrimaryAccount)
assert.Equal(t, account.Id, account2.Id)
}
func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
@@ -3296,6 +3336,123 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
})
}
func TestPropagateUserGroupMemberships(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err)
ctx := context.Background()
initiatorId := "test-user"
domain := "example.com"
account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain)
require.NoError(t, err)
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
err = manager.Store.AddPeerToAccount(ctx, peer1)
require.NoError(t, err)
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
err = manager.Store.AddPeerToAccount(ctx, peer2)
require.NoError(t, err)
t.Run("should skip propagation when the user has no groups", func(t *testing.T) {
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err)
assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers)
})
t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) {
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
require.NoError(t, manager.Store.CreateGroup(ctx, group1))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group1.ID)
require.NoError(t, manager.Store.SaveUser(ctx, user))
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err)
assert.True(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers)
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, group1.ID)
require.NoError(t, err)
assert.Len(t, group.Peers, 2)
assert.Contains(t, group.Peers, "peer1")
assert.Contains(t, group.Peers, "peer2")
})
t.Run("should update membership and account peers for used groups", func(t *testing.T) {
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
require.NoError(t, manager.Store.CreateGroup(ctx, group2))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group2.ID)
require.NoError(t, manager.Store.SaveUser(ctx, user))
_, err = manager.SavePolicy(context.Background(), account.Id, initiatorId, &types.Policy{
Name: "Group1 Policy",
AccountID: account.Id,
Enabled: true,
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"group1"},
Destinations: []string{"group2"},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
},
},
}, true)
require.NoError(t, err)
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err)
assert.True(t, groupsUpdated)
assert.True(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
require.NoError(t, err)
for _, group := range groups {
assert.Len(t, group.Peers, 2)
assert.Contains(t, group.Peers, "peer1")
assert.Contains(t, group.Peers, "peer2")
}
})
t.Run("should not update membership or account peers when no changes", func(t *testing.T) {
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err)
assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers)
})
t.Run("should not remove peers when groups are removed from user", func(t *testing.T) {
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err)
user.AutoGroups = []string{"group1"}
require.NoError(t, manager.Store.SaveUser(ctx, user))
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err)
assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
require.NoError(t, err)
for _, group := range groups {
assert.Len(t, group.Peers, 2)
assert.Contains(t, group.Peers, "peer1")
assert.Contains(t, group.Peers, "peer2")
}
})
}
func TestDefaultAccountManager_AddNewUserToDomainAccount(t *testing.T) {
testCases := []struct {
name string
@@ -3339,3 +3496,141 @@ func TestDefaultAccountManager_AddNewUserToDomainAccount(t *testing.T) {
})
}
}
func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
require.NoError(t, err)
t.Run("should return account onboarding when onboarding exist", func(t *testing.T) {
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
require.NoError(t, err)
require.NotNil(t, onboarding)
assert.Equal(t, account.Id, onboarding.AccountID)
assert.Equal(t, true, onboarding.OnboardingFlowPending)
assert.Equal(t, true, onboarding.SignupFormPending)
if onboarding.UpdatedAt.IsZero() {
t.Errorf("Onboarding was not retrieved from the store")
}
})
t.Run("should return account onboarding when onboard don't exist", func(t *testing.T) {
account.Id = "with-zero-onboarding"
account.Onboarding = types.AccountOnboarding{}
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
require.NoError(t, err)
require.NotNil(t, onboarding)
_, err = manager.Store.GetAccountOnboarding(context.Background(), account.Id)
require.Error(t, err, "should return error when onboarding is not set")
})
}
func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
require.NoError(t, err)
onboarding := &types.AccountOnboarding{
OnboardingFlowPending: true,
SignupFormPending: true,
}
t.Run("update onboarding with no change", func(t *testing.T) {
updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding)
require.NoError(t, err)
assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending)
assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending)
if updated.UpdatedAt.IsZero() {
t.Errorf("Onboarding was updated in the store")
}
})
onboarding.OnboardingFlowPending = false
onboarding.SignupFormPending = false
t.Run("update onboarding", func(t *testing.T) {
updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding)
require.NoError(t, err)
require.NotNil(t, updated)
assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending)
assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending)
})
t.Run("update onboarding with no onboarding", func(t *testing.T) {
_, err = manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, nil)
require.NoError(t, err)
})
}
func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
key1, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
key2, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer1, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
require.NoError(t, err, "unable to add peer1")
peer2, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
require.NoError(t, err, "unable to add peer2")
t.Run("update peer IP successfully", func(t *testing.T) {
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get account")
newIP, err := types.AllocatePeerIP(account.Network.Net, []net.IP{peer1.IP, peer2.IP})
require.NoError(t, err, "unable to allocate new IP")
newAddr := netip.MustParseAddr(newIP.String())
err = manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, newAddr)
require.NoError(t, err, "unable to update peer IP")
updatedPeer, err := manager.GetPeer(context.Background(), accountID, peer1.ID, userID)
require.NoError(t, err, "unable to get updated peer")
assert.Equal(t, newIP.String(), updatedPeer.IP.String(), "peer IP should be updated")
})
t.Run("update peer IP with same IP should be no-op", func(t *testing.T) {
currentAddr := netip.MustParseAddr(peer1.IP.String())
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, currentAddr)
require.NoError(t, err, "updating with same IP should not error")
})
t.Run("update peer IP with collision should fail", func(t *testing.T) {
peer2Addr := netip.MustParseAddr(peer2.IP.String())
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, peer2Addr)
require.Error(t, err, "should fail when IP is already assigned")
assert.Contains(t, err.Error(), "already assigned", "error should mention IP collision")
})
t.Run("update peer IP outside network range should fail", func(t *testing.T) {
invalidAddr := netip.MustParseAddr("192.168.1.100")
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, invalidAddr)
require.Error(t, err, "should fail when IP is outside network range")
assert.Contains(t, err.Error(), "not within the account network range", "error should mention network range")
})
t.Run("update peer IP with invalid peer ID should fail", func(t *testing.T) {
newAddr := netip.MustParseAddr("100.64.0.101")
err := manager.UpdatePeerIP(context.Background(), accountID, userID, "invalid-peer-id", newAddr)
require.Error(t, err, "should fail with invalid peer ID")
})
}