mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-05 08:36:37 +00:00
Merge branch 'main' into handle-existing-domain-user
# Conflicts: # management/server/account.go # management/server/account_test.go
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
@@ -14,7 +15,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/cache"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
@@ -373,7 +374,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, testCase := range tt {
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io", false)
|
||||
account.UpdateSettings(&testCase.accountSettings)
|
||||
account.Network = network
|
||||
account.Peers = testCase.peers
|
||||
@@ -398,7 +399,7 @@ func TestNewAccount(t *testing.T) {
|
||||
domain := "netbird.io"
|
||||
userId := "account_creator"
|
||||
accountID := "account_id"
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain, false)
|
||||
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
|
||||
}
|
||||
|
||||
@@ -640,7 +641,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain)
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain, false)
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
|
||||
@@ -782,7 +783,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID)
|
||||
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthNone, accountID)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists, "expected to get existing account after creation using userid")
|
||||
|
||||
@@ -793,7 +794,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
}
|
||||
|
||||
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) {
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -899,11 +900,11 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
|
||||
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
|
||||
}
|
||||
|
||||
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, "service-user-1")
|
||||
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, "service-user-1")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 0)
|
||||
|
||||
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, userId)
|
||||
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, userId)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 0)
|
||||
}
|
||||
@@ -1159,7 +1160,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
Name: "GroupA",
|
||||
Peers: []string{},
|
||||
}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
|
||||
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1194,7 +1195,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
}()
|
||||
|
||||
group.Peers = []string{peer1.ID, peer2.ID, peer3.ID}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
|
||||
if err := manager.UpdateGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1208,6 +1209,14 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
// Ensure that we do not receive an update message before the policy is deleted
|
||||
time.Sleep(time.Second)
|
||||
select {
|
||||
case <-updMsg:
|
||||
t.Logf("received addPeer update message before policy deletion")
|
||||
default:
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -1232,11 +1241,12 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
AccountID: account.Id,
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID},
|
||||
}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
|
||||
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1284,7 +1294,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer3.ID},
|
||||
}
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil {
|
||||
if err := manager.CreateGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1335,11 +1345,11 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
|
||||
ID: "groupA",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
}, true)
|
||||
})
|
||||
|
||||
require.NoError(t, err, "failed to save group")
|
||||
|
||||
@@ -1664,9 +1674,10 @@ func TestAccount_Copy(t *testing.T) {
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"group1": {
|
||||
ID: "group1",
|
||||
Peers: []string{"peer1"},
|
||||
Resources: []types.Resource{},
|
||||
ID: "group1",
|
||||
Peers: []string{"peer1"},
|
||||
Resources: []types.Resource{},
|
||||
GroupPeers: []types.GroupPeer{},
|
||||
},
|
||||
},
|
||||
Policies: []*types.Policy{
|
||||
@@ -1775,7 +1786,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
assert.NotNil(t, settings)
|
||||
@@ -1805,9 +1816,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
account, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
@@ -1825,11 +1837,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
// disable expiration first
|
||||
update := peer.Copy()
|
||||
update.LoginExpirationEnabled = false
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
|
||||
_, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
|
||||
require.NoError(t, err, "unable to update peer")
|
||||
// enabling expiration should trigger the routine
|
||||
update.LoginExpirationEnabled = true
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
|
||||
_, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
|
||||
require.NoError(t, err, "unable to update peer")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1856,15 +1868,13 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
wg.Add(1)
|
||||
manager.peerLoginExpiry = &MockScheduler{
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {
|
||||
wg.Done()
|
||||
},
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
wg.Done()
|
||||
},
|
||||
@@ -1919,9 +1929,10 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
},
|
||||
}
|
||||
// enabling PeerLoginExpirationEnabled should trigger the expiration job
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
@@ -1935,6 +1946,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
failed = waitTimeout(wg, time.Second)
|
||||
@@ -1950,15 +1962,16 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
||||
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err, "unable to get account settings")
|
||||
|
||||
assert.False(t, settings.PeerLoginExpirationEnabled)
|
||||
@@ -1967,12 +1980,14 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Second,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour * 24 * 181,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
|
||||
}
|
||||
@@ -2604,6 +2619,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", "postgres")
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
@@ -2611,11 +2627,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
account := &types.Account{
|
||||
Id: "accountID",
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
|
||||
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
|
||||
"peer3": {ID: "peer3", Key: "key3", UserID: "user1"},
|
||||
"peer4": {ID: "peer4", Key: "key4", UserID: "user2"},
|
||||
"peer5": {ID: "peer5", Key: "key5", UserID: "user2"},
|
||||
"peer1": {ID: "peer1", Key: "key1", UserID: "user1", IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"},
|
||||
"peer2": {ID: "peer2", Key: "key2", UserID: "user1", IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"},
|
||||
"peer3": {ID: "peer3", Key: "key3", UserID: "user1", IP: net.IP{3, 3, 3, 3}, DNSLabel: "peer3.domain.test"},
|
||||
"peer4": {ID: "peer4", Key: "key4", UserID: "user2", IP: net.IP{4, 4, 4, 4}, DNSLabel: "peer4.domain.test"},
|
||||
"peer5": {ID: "peer5", Key: "key5", UserID: "user2", IP: net.IP{5, 5, 5, 5}, DNSLabel: "peer5.domain.test"},
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
"group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}},
|
||||
@@ -2639,7 +2655,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
|
||||
})
|
||||
@@ -2653,7 +2669,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err := manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Empty(t, user.AutoGroups, "auto groups must be empty")
|
||||
})
|
||||
@@ -2667,18 +2683,18 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err := manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
|
||||
t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
|
||||
account.Users["user1"].AutoGroups = []string{"group1"}
|
||||
assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"]))
|
||||
assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"]))
|
||||
|
||||
claims := nbcontext.UserAuth{
|
||||
UserId: "user1",
|
||||
@@ -2688,11 +2704,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
@@ -2706,7 +2722,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
|
||||
})
|
||||
@@ -2720,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
|
||||
})
|
||||
@@ -2734,11 +2750,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID")
|
||||
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, "accountID")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, groups, 3, "new group3 should be added")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1, "new group should be added")
|
||||
})
|
||||
@@ -2752,7 +2768,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
|
||||
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
|
||||
@@ -2767,7 +2783,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
err = manager.SyncUserJWTGroups(context.Background(), claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
|
||||
})
|
||||
@@ -2875,7 +2891,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3135,11 +3151,11 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
minMsPerOpCICD float64
|
||||
maxMsPerOpCICD float64
|
||||
}{
|
||||
{"Small", 50, 5, 7, 20, 10, 80},
|
||||
{"Small", 50, 5, 7, 20, 5, 80},
|
||||
{"Medium", 500, 100, 5, 40, 30, 140},
|
||||
{"Large", 5000, 200, 80, 120, 140, 390},
|
||||
{"Small single", 50, 10, 7, 20, 10, 80},
|
||||
{"Medium single", 500, 10, 5, 40, 20, 85},
|
||||
{"Small single", 50, 10, 7, 20, 6, 80},
|
||||
{"Medium single", 500, 10, 5, 40, 15, 85},
|
||||
{"Large 5", 5000, 15, 80, 120, 80, 200},
|
||||
}
|
||||
|
||||
@@ -3198,7 +3214,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateAccountByPrivateDomain(t *testing.T) {
|
||||
func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -3209,9 +3225,10 @@ func Test_CreateAccountByPrivateDomain(t *testing.T) {
|
||||
initiatorId := "test-user"
|
||||
domain := "example.com"
|
||||
|
||||
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
account, created, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.True(t, created)
|
||||
assert.False(t, account.IsDomainPrimaryAccount)
|
||||
assert.Equal(t, domain, account.Domain)
|
||||
assert.Equal(t, types.PrivateCategory, account.DomainCategory)
|
||||
@@ -3220,9 +3237,25 @@ func Test_CreateAccountByPrivateDomain(t *testing.T) {
|
||||
assert.Equal(t, 0, len(account.Users))
|
||||
assert.Equal(t, 0, len(account.SetupKeys))
|
||||
|
||||
// retry should fail
|
||||
_, err = manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.Error(t, err)
|
||||
// should return a new account because the previous one is not primary
|
||||
account2, created2, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.True(t, created2)
|
||||
assert.False(t, account2.IsDomainPrimaryAccount)
|
||||
assert.Equal(t, domain, account2.Domain)
|
||||
assert.Equal(t, types.PrivateCategory, account2.DomainCategory)
|
||||
assert.Equal(t, initiatorId, account2.CreatedBy)
|
||||
assert.Equal(t, 1, len(account2.Groups))
|
||||
assert.Equal(t, 0, len(account2.Users))
|
||||
assert.Equal(t, 0, len(account2.SetupKeys))
|
||||
|
||||
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, account.IsDomainPrimaryAccount)
|
||||
|
||||
_, err = manager.UpdateToPrimaryAccount(ctx, account2.Id)
|
||||
assert.Error(t, err, "should not be able to update a second account to primary")
|
||||
}
|
||||
|
||||
func Test_UpdateToPrimaryAccount(t *testing.T) {
|
||||
@@ -3236,14 +3269,21 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
|
||||
initiatorId := "test-user"
|
||||
domain := "example.com"
|
||||
|
||||
account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
account, created, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, created)
|
||||
assert.False(t, account.IsDomainPrimaryAccount)
|
||||
assert.Equal(t, domain, account.Domain)
|
||||
|
||||
// retry should fail
|
||||
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, account.IsDomainPrimaryAccount)
|
||||
|
||||
account2, created2, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, created2)
|
||||
assert.True(t, account.IsDomainPrimaryAccount)
|
||||
assert.Equal(t, account.Id, account2.Id)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
|
||||
@@ -3296,6 +3336,123 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
initiatorId := "test-user"
|
||||
domain := "example.com"
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
|
||||
err = manager.Store.AddPeerToAccount(ctx, peer1)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
|
||||
err = manager.Store.AddPeerToAccount(ctx, peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("should skip propagation when the user has no groups", func(t *testing.T) {
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
})
|
||||
|
||||
t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) {
|
||||
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
|
||||
require.NoError(t, manager.Store.CreateGroup(ctx, group1))
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
|
||||
require.NoError(t, err)
|
||||
|
||||
user.AutoGroups = append(user.AutoGroups, group1.ID)
|
||||
require.NoError(t, manager.Store.SaveUser(ctx, user))
|
||||
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
|
||||
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, group1.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, group.Peers, 2)
|
||||
assert.Contains(t, group.Peers, "peer1")
|
||||
assert.Contains(t, group.Peers, "peer2")
|
||||
})
|
||||
|
||||
t.Run("should update membership and account peers for used groups", func(t *testing.T) {
|
||||
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
|
||||
require.NoError(t, manager.Store.CreateGroup(ctx, group2))
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
|
||||
require.NoError(t, err)
|
||||
|
||||
user.AutoGroups = append(user.AutoGroups, group2.ID)
|
||||
require.NoError(t, manager.Store.SaveUser(ctx, user))
|
||||
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, initiatorId, &types.Policy{
|
||||
Name: "Group1 Policy",
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"group1"},
|
||||
Destinations: []string{"group2"},
|
||||
Bidirectional: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, groupsUpdated)
|
||||
assert.True(t, groupChangesAffectPeers)
|
||||
|
||||
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
|
||||
require.NoError(t, err)
|
||||
for _, group := range groups {
|
||||
assert.Len(t, group.Peers, 2)
|
||||
assert.Contains(t, group.Peers, "peer1")
|
||||
assert.Contains(t, group.Peers, "peer2")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should not update membership or account peers when no changes", func(t *testing.T) {
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
})
|
||||
|
||||
t.Run("should not remove peers when groups are removed from user", func(t *testing.T) {
|
||||
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
|
||||
require.NoError(t, err)
|
||||
|
||||
user.AutoGroups = []string{"group1"}
|
||||
require.NoError(t, manager.Store.SaveUser(ctx, user))
|
||||
|
||||
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, groupsUpdated)
|
||||
assert.False(t, groupChangesAffectPeers)
|
||||
|
||||
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
|
||||
require.NoError(t, err)
|
||||
for _, group := range groups {
|
||||
assert.Len(t, group.Peers, 2)
|
||||
assert.Contains(t, group.Peers, "peer1")
|
||||
assert.Contains(t, group.Peers, "peer2")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_AddNewUserToDomainAccount(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -3339,3 +3496,141 @@ func TestDefaultAccountManager_AddNewUserToDomainAccount(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("should return account onboarding when onboarding exist", func(t *testing.T) {
|
||||
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, onboarding)
|
||||
assert.Equal(t, account.Id, onboarding.AccountID)
|
||||
assert.Equal(t, true, onboarding.OnboardingFlowPending)
|
||||
assert.Equal(t, true, onboarding.SignupFormPending)
|
||||
if onboarding.UpdatedAt.IsZero() {
|
||||
t.Errorf("Onboarding was not retrieved from the store")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should return account onboarding when onboard don't exist", func(t *testing.T) {
|
||||
account.Id = "with-zero-onboarding"
|
||||
account.Onboarding = types.AccountOnboarding{}
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, onboarding)
|
||||
_, err = manager.Store.GetAccountOnboarding(context.Background(), account.Id)
|
||||
require.Error(t, err, "should return error when onboarding is not set")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
onboarding := &types.AccountOnboarding{
|
||||
OnboardingFlowPending: true,
|
||||
SignupFormPending: true,
|
||||
}
|
||||
|
||||
t.Run("update onboarding with no change", func(t *testing.T) {
|
||||
updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending)
|
||||
assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending)
|
||||
if updated.UpdatedAt.IsZero() {
|
||||
t.Errorf("Onboarding was updated in the store")
|
||||
}
|
||||
})
|
||||
|
||||
onboarding.OnboardingFlowPending = false
|
||||
onboarding.SignupFormPending = false
|
||||
|
||||
t.Run("update onboarding", func(t *testing.T) {
|
||||
updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending)
|
||||
assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending)
|
||||
})
|
||||
|
||||
t.Run("update onboarding with no onboarding", func(t *testing.T) {
|
||||
_, err = manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, nil)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key1, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
key2, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
Key: key1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer1")
|
||||
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
Key: key2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer2")
|
||||
|
||||
t.Run("update peer IP successfully", func(t *testing.T) {
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get account")
|
||||
|
||||
newIP, err := types.AllocatePeerIP(account.Network.Net, []net.IP{peer1.IP, peer2.IP})
|
||||
require.NoError(t, err, "unable to allocate new IP")
|
||||
|
||||
newAddr := netip.MustParseAddr(newIP.String())
|
||||
err = manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, newAddr)
|
||||
require.NoError(t, err, "unable to update peer IP")
|
||||
|
||||
updatedPeer, err := manager.GetPeer(context.Background(), accountID, peer1.ID, userID)
|
||||
require.NoError(t, err, "unable to get updated peer")
|
||||
assert.Equal(t, newIP.String(), updatedPeer.IP.String(), "peer IP should be updated")
|
||||
})
|
||||
|
||||
t.Run("update peer IP with same IP should be no-op", func(t *testing.T) {
|
||||
currentAddr := netip.MustParseAddr(peer1.IP.String())
|
||||
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, currentAddr)
|
||||
require.NoError(t, err, "updating with same IP should not error")
|
||||
})
|
||||
|
||||
t.Run("update peer IP with collision should fail", func(t *testing.T) {
|
||||
peer2Addr := netip.MustParseAddr(peer2.IP.String())
|
||||
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, peer2Addr)
|
||||
require.Error(t, err, "should fail when IP is already assigned")
|
||||
assert.Contains(t, err.Error(), "already assigned", "error should mention IP collision")
|
||||
})
|
||||
|
||||
t.Run("update peer IP outside network range should fail", func(t *testing.T) {
|
||||
invalidAddr := netip.MustParseAddr("192.168.1.100")
|
||||
err := manager.UpdatePeerIP(context.Background(), accountID, userID, peer1.ID, invalidAddr)
|
||||
require.Error(t, err, "should fail when IP is outside network range")
|
||||
assert.Contains(t, err.Error(), "not within the account network range", "error should mention network range")
|
||||
})
|
||||
|
||||
t.Run("update peer IP with invalid peer ID should fail", func(t *testing.T) {
|
||||
newAddr := netip.MustParseAddr("100.64.0.101")
|
||||
err := manager.UpdatePeerIP(context.Background(), accountID, userID, "invalid-peer-id", newAddr)
|
||||
require.Error(t, err, "should fail with invalid peer ID")
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user