mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-03 23:56:38 +00:00
@@ -401,7 +401,14 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, testCase := range tt {
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
|
||||
store := newStore(t)
|
||||
|
||||
err := newAccountWithId(context.Background(), store, "account-1", userID, "netbird.io")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
account, err := store.GetAccount(context.Background(), "account-1")
|
||||
require.NoError(t, err, "failed to get account")
|
||||
|
||||
account.UpdateSettings(&testCase.accountSettings)
|
||||
account.Network = network
|
||||
account.Peers = testCase.peers
|
||||
@@ -419,6 +426,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil)
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
|
||||
store.Close(context.Background())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -426,27 +435,35 @@ func TestNewAccount(t *testing.T) {
|
||||
domain := "netbird.io"
|
||||
userId := "account_creator"
|
||||
accountID := "account_id"
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain)
|
||||
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
|
||||
err := newAccountWithId(context.Background(), store, accountID, userId, domain)
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
account, err := store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "failed to get account")
|
||||
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
|
||||
}
|
||||
|
||||
func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
func TestAccountManager_GetOrCreateAccountIDByUser(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if account == nil {
|
||||
if accountID == "" {
|
||||
t.Fatalf("expected to create an account for a user %s", userID)
|
||||
return
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccountByUser(context.Background(), userID)
|
||||
account, err := manager.Store.GetAccountByUser(context.Background(), userID)
|
||||
if err != nil {
|
||||
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID)
|
||||
return
|
||||
@@ -669,15 +686,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain)
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
// as initAccount was created without account id we have to take the id after account initialization
|
||||
// that happens inside the GetAccountIDByUserID where the id is getting generated
|
||||
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
|
||||
|
||||
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get init account failed")
|
||||
|
||||
@@ -693,44 +707,53 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account groups")
|
||||
|
||||
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||
require.Len(t, accountGroups, 1, "only ALL group should exists")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
err := manager.Store.SaveAccount(context.Background(), initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings)
|
||||
require.NoError(t, err, "failed to update account settings")
|
||||
|
||||
totalAccounts, err := manager.Store.GetTotalAccounts(context.Background())
|
||||
require.NoError(t, err, "failed to get total accounts")
|
||||
require.Equal(t, int64(1), totalAccounts, "only one account should exist")
|
||||
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account groups")
|
||||
|
||||
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||
require.Len(t, accountGroups, 1, "if group claim is not set no group added from JWT")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
err := manager.Store.SaveAccount(context.Background(), initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings)
|
||||
require.NoError(t, err, "failed to update account settings")
|
||||
|
||||
totalAccounts, err := manager.Store.GetTotalAccounts(context.Background())
|
||||
require.NoError(t, err, "failed to get total accounts")
|
||||
require.Equal(t, int64(1), totalAccounts, "only one account should exist")
|
||||
|
||||
accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to check account existence")
|
||||
require.True(t, exists, "account should exist")
|
||||
|
||||
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||
accountGroups, err := manager.GetAllGroups(context.Background(), accountID, userId)
|
||||
require.NoError(t, err, "failed to get account groups")
|
||||
require.Len(t, accountGroups, 3, "groups should be added to the account")
|
||||
|
||||
groupsByNames := map[string]*group.Group{}
|
||||
for _, g := range account.Groups {
|
||||
for _, g := range accountGroups {
|
||||
groupsByNames[g.Name] = g
|
||||
}
|
||||
|
||||
@@ -746,27 +769,23 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
func TestAccountManager_GetAccountInfoFromPAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
err := newAccountWithId(context.Background(), store, "account_id", "testuser", "")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
account.Users["someUser"] = &User{
|
||||
Id: "someUser",
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
"tokenId": {
|
||||
ID: "tokenId",
|
||||
UserID: "someUser",
|
||||
HashedToken: encodedHashedToken,
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
|
||||
userPAT := &PersonalAccessToken{
|
||||
ID: "tokenId",
|
||||
UserID: "testuser",
|
||||
HashedToken: encodedHashedToken,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT)
|
||||
require.NoError(t, err, "failed to save PAT")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -778,31 +797,27 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
}
|
||||
|
||||
assert.Equal(t, "account_id", user.AccountID)
|
||||
assert.Equal(t, "someUser", user.Id)
|
||||
assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID)
|
||||
assert.Equal(t, "testuser", user.Id)
|
||||
assert.Equal(t, userPAT, pat)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
err := newAccountWithId(context.Background(), store, "account_id", "testuser", "")
|
||||
require.NoError(t, err, "failed to create account")
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
account.Users["someUser"] = &User{
|
||||
Id: "someUser",
|
||||
PATs: map[string]*PersonalAccessToken{
|
||||
"tokenId": {
|
||||
ID: "tokenId",
|
||||
HashedToken: encodedHashedToken,
|
||||
LastUsed: time.Time{},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
|
||||
userPAT := &PersonalAccessToken{
|
||||
ID: "tokenId",
|
||||
UserID: "someUser",
|
||||
HashedToken: encodedHashedToken,
|
||||
LastUsed: time.Time{},
|
||||
}
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT)
|
||||
require.NoError(t, err, "failed to save PAT")
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -813,11 +828,10 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
t.Fatalf("Error when marking PAT used: %s", err)
|
||||
}
|
||||
|
||||
account, err = am.Store.GetAccount(context.Background(), "account_id")
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting account: %s", err)
|
||||
}
|
||||
assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero())
|
||||
userPAT, err = store.GetPATByID(context.Background(), LockingStrengthShare, userPAT.UserID, userPAT.ID)
|
||||
require.NoError(t, err, "failed to get PAT")
|
||||
|
||||
assert.True(t, !userPAT.LastUsed.IsZero())
|
||||
}
|
||||
|
||||
func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
@@ -828,15 +842,15 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
userId := "test_user"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "")
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if account == nil {
|
||||
if accountID == "" {
|
||||
t.Fatalf("expected to create an account for a user %s", userId)
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccountByUser(context.Background(), userId)
|
||||
account, err := manager.Store.GetAccountByUser(context.Background(), userId)
|
||||
if err != nil {
|
||||
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
|
||||
}
|
||||
@@ -855,32 +869,22 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
|
||||
userId := "test_user"
|
||||
domain := "hotmail.com"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if account == nil {
|
||||
t.Fatalf("expected to create an account for a user %s", userId)
|
||||
}
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain)
|
||||
require.NoError(t, err, "failed to get or create account by user")
|
||||
require.NotEmptyf(t, accountID, "expected to create an account for a user %s", userId)
|
||||
|
||||
if account != nil && account.Domain != domain {
|
||||
t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain)
|
||||
}
|
||||
accDomain, _, err := manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account domain and category")
|
||||
require.Equal(t, domain, accDomain, "expected account domain to match")
|
||||
|
||||
domain = "gmail.com"
|
||||
|
||||
account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
|
||||
if err != nil {
|
||||
t.Fatalf("got the following error while retrieving existing acc: %v", err)
|
||||
}
|
||||
accountID, err = manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain)
|
||||
require.NoError(t, err, "failed to get or create account by user")
|
||||
|
||||
if account == nil {
|
||||
t.Fatalf("expected to get an account for a user %s", userId)
|
||||
}
|
||||
|
||||
if account != nil && account.Domain != domain {
|
||||
t.Errorf("updating domain. expected %s got %s", domain, account.Domain)
|
||||
}
|
||||
accDomain, _, err = manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account domain and category")
|
||||
require.Equal(t, domain, accDomain, "expected account domain to match")
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
@@ -912,12 +916,11 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
}
|
||||
|
||||
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) {
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account, nil
|
||||
return am.Store.GetAccount(context.Background(), accountID)
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccount(t *testing.T) {
|
||||
@@ -1164,23 +1167,18 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "netbird.cloud")
|
||||
require.NoError(t, err, "failed to get or create account by user")
|
||||
|
||||
serial := account.Network.CurrentSerial() // should be 0
|
||||
network, err := manager.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err, "failed to get account network")
|
||||
|
||||
if account.Network.Serial != 0 {
|
||||
t.Errorf("expecting account network to have an initial Serial=0")
|
||||
return
|
||||
}
|
||||
serial := network.CurrentSerial() // should be 0
|
||||
require.Equal(t, 0, int(serial), "expected account network to have an initial Serial=0")
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err, "failed to generate private key")
|
||||
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
expectedUserID := userID
|
||||
|
||||
@@ -1188,16 +1186,10 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err, "failed to add peer")
|
||||
|
||||
account, err = manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "failed to get account")
|
||||
|
||||
if peer.Key != expectedPeerKey {
|
||||
t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key)
|
||||
|
||||
Reference in New Issue
Block a user