mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-26 12:16:39 +00:00
Merge branch 'main' into feature/port-forwarding
This commit is contained in:
@@ -68,7 +68,7 @@ type AccountManager interface {
|
||||
SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
|
||||
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
|
||||
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
|
||||
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
|
||||
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error)
|
||||
@@ -80,7 +80,7 @@ type AccountManager interface {
|
||||
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
|
||||
GetPATInfo(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
|
||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
||||
@@ -97,7 +97,7 @@ type AccountManager interface {
|
||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error)
|
||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error)
|
||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
|
||||
@@ -150,6 +150,7 @@ type AccountManager interface {
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
||||
UpdateAccountPeers(ctx context.Context, accountID string)
|
||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
}
|
||||
|
||||
type DefaultAccountManager struct {
|
||||
@@ -622,6 +623,12 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
if user.Role != types.UserRoleOwner {
|
||||
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
|
||||
}
|
||||
|
||||
userInfosMap, err := am.BuildUserInfosForAccount(ctx, accountID, userID, maps.Values(account.Users))
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
for _, otherUser := range account.Users {
|
||||
if otherUser.IsServiceUser {
|
||||
continue
|
||||
@@ -631,13 +638,23 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
continue
|
||||
}
|
||||
|
||||
deleteUserErr := am.deleteRegularUser(ctx, account, userID, otherUser.Id)
|
||||
userInfo, ok := userInfosMap[otherUser.Id]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "user info not found for user %s", otherUser.Id)
|
||||
}
|
||||
|
||||
_, deleteUserErr := am.deleteRegularUser(ctx, accountID, userID, userInfo)
|
||||
if deleteUserErr != nil {
|
||||
return deleteUserErr
|
||||
}
|
||||
}
|
||||
|
||||
err = am.deleteRegularUser(ctx, account, userID, userID)
|
||||
userInfo, ok := userInfosMap[userID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "user info not found for user %s", userID)
|
||||
}
|
||||
|
||||
_, err = am.deleteRegularUser(ctx, accountID, userID, userInfo)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err)
|
||||
return err
|
||||
@@ -694,20 +711,8 @@ func isNil(i idp.Manager) bool {
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||
if !isNil(am.idpManager) {
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cachedAccount := &types.Account{
|
||||
Id: accountID,
|
||||
Users: make(map[string]*types.User),
|
||||
}
|
||||
for _, user := range accountUsers {
|
||||
cachedAccount.Users[user.Id] = user
|
||||
}
|
||||
|
||||
// user can be nil if it wasn't found (e.g., just created)
|
||||
user, err := am.lookupUserInCache(ctx, userID, cachedAccount)
|
||||
user, err := am.lookupUserInCache(ctx, userID, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -783,10 +788,15 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
|
||||
}
|
||||
|
||||
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
|
||||
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *types.Account) (*idp.UserData, error) {
|
||||
users := make(map[string]userLoggedInOnce, len(account.Users))
|
||||
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make(map[string]userLoggedInOnce, len(accountUsers))
|
||||
// ignore service users and users provisioned by integrations than are never logged in
|
||||
for _, user := range account.Users {
|
||||
for _, user := range accountUsers {
|
||||
if user.IsServiceUser {
|
||||
continue
|
||||
}
|
||||
@@ -795,8 +805,8 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
|
||||
}
|
||||
users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero())
|
||||
}
|
||||
log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, account.Id)
|
||||
userData, err := am.lookupCache(ctx, users, account.Id)
|
||||
log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, accountID)
|
||||
userData, err := am.lookupCache(ctx, users, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -809,13 +819,13 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
|
||||
|
||||
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
|
||||
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
|
||||
user, err := account.FindUser(userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, account.Id)
|
||||
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := user.IntegrationReference.CacheKey(account.Id, userID)
|
||||
key := user.IntegrationReference.CacheKey(accountID, userID)
|
||||
ud, err := am.externalCacheManager.Get(am.ctx, key)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get externalCache for key: %s, error: %s", key, err)
|
||||
@@ -1055,9 +1065,9 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||
defer unlockAccount()
|
||||
|
||||
usersMap := make(map[string]*types.User)
|
||||
usersMap[claims.UserId] = types.NewRegularUser(claims.UserId)
|
||||
err := am.Store.SaveUsers(domainAccountID, usersMap)
|
||||
newUser := types.NewRegularUser(claims.UserId)
|
||||
newUser.AccountID = domainAccountID
|
||||
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1080,12 +1090,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
||||
user, err := am.lookupUserInCache(ctx, userID, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1095,17 +1100,17 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
|
||||
}
|
||||
|
||||
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
|
||||
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, account.Id)
|
||||
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID)
|
||||
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
|
||||
// Our job is to just reload cache.
|
||||
go func() {
|
||||
_, err = am.refreshCache(ctx, account.Id)
|
||||
_, err = am.refreshCache(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id)
|
||||
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, accountID)
|
||||
return
|
||||
}
|
||||
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, account.Id)
|
||||
am.StoreEvent(ctx, userID, userID, account.Id, activity.UserJoined, nil)
|
||||
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, accountID)
|
||||
am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -1114,33 +1119,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
|
||||
|
||||
// MarkPATUsed marks a personal access token as used
|
||||
func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error {
|
||||
|
||||
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccountByUser(ctx, user.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
defer unlock()
|
||||
|
||||
account, err = am.Store.GetAccountByUser(ctx, user.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pat, ok := account.Users[user.Id].PATs[tokenID]
|
||||
if !ok {
|
||||
return fmt.Errorf("token not found")
|
||||
}
|
||||
|
||||
pat.LastUsed = util.ToPtr(time.Now().UTC())
|
||||
|
||||
return am.Store.SaveAccount(ctx, account)
|
||||
return am.Store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID)
|
||||
}
|
||||
|
||||
// GetAccount returns an account associated with this account ID.
|
||||
@@ -1148,52 +1127,64 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetAccountFromPAT returns Account and User associated with a personal access token
|
||||
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
|
||||
// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
|
||||
func (am *DefaultAccountManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
|
||||
user, pat, err = am.extractPATFromToken(ctx, token)
|
||||
if err != nil {
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
|
||||
if err != nil {
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
return user, pat, domain, category, nil
|
||||
}
|
||||
|
||||
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
|
||||
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) {
|
||||
if len(token) != types.PATLength {
|
||||
return nil, nil, nil, fmt.Errorf("token has wrong length")
|
||||
return nil, nil, fmt.Errorf("token has incorrect length")
|
||||
}
|
||||
|
||||
prefix := token[:len(types.PATPrefix)]
|
||||
if prefix != types.PATPrefix {
|
||||
return nil, nil, nil, fmt.Errorf("token has wrong prefix")
|
||||
return nil, nil, fmt.Errorf("token has wrong prefix")
|
||||
}
|
||||
secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
|
||||
encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
|
||||
|
||||
verificationChecksum, err := base62.Decode(encodedChecksum)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
|
||||
return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
|
||||
}
|
||||
|
||||
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
if secretChecksum != verificationChecksum {
|
||||
return nil, nil, nil, fmt.Errorf("token checksum does not match")
|
||||
return nil, nil, fmt.Errorf("token checksum does not match")
|
||||
}
|
||||
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
tokenID, err := am.Store.GetTokenIDByHashedToken(ctx, encodedHashedToken)
|
||||
|
||||
var user *types.User
|
||||
var pat *types.PersonalAccessToken
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccountByUser(ctx, user.Id)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
pat := user.PATs[tokenID]
|
||||
if pat == nil {
|
||||
return nil, nil, nil, fmt.Errorf("personal access token not found")
|
||||
}
|
||||
|
||||
return account, user, pat, nil
|
||||
return user, pat, nil
|
||||
}
|
||||
|
||||
// GetAccountByID returns an account associated with this account ID.
|
||||
@@ -1339,7 +1330,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
return fmt.Errorf("error getting user peers: %w", err)
|
||||
}
|
||||
|
||||
updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
|
||||
updatedGroups, err := updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error modifying user peers in groups: %w", err)
|
||||
}
|
||||
|
||||
@@ -733,6 +733,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"tokenId": {
|
||||
ID: "tokenId",
|
||||
UserID: "someUser",
|
||||
HashedToken: encodedHashedToken,
|
||||
},
|
||||
},
|
||||
@@ -746,14 +747,14 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
Store: store,
|
||||
}
|
||||
|
||||
account, user, pat, err := am.GetAccountFromPAT(context.Background(), token)
|
||||
user, pat, _, _, err := am.GetPATInfo(context.Background(), token)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting Account from PAT: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, "account_id", account.Id)
|
||||
assert.Equal(t, "account_id", user.AccountID)
|
||||
assert.Equal(t, "someUser", user.Id)
|
||||
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat)
|
||||
assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
@@ -3018,11 +3019,11 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
||||
minMsPerOpCICD float64
|
||||
maxMsPerOpCICD float64
|
||||
}{
|
||||
{"Small", 50, 5, 1, 5, 3, 19},
|
||||
{"Medium", 500, 100, 7, 22, 10, 90},
|
||||
{"Large", 5000, 200, 65, 110, 60, 240},
|
||||
{"Small", 50, 5, 1, 5, 3, 24},
|
||||
{"Medium", 500, 100, 7, 22, 10, 135},
|
||||
{"Large", 5000, 200, 65, 110, 60, 320},
|
||||
{"Small single", 50, 10, 1, 4, 3, 80},
|
||||
{"Medium single", 500, 10, 7, 13, 10, 37},
|
||||
{"Medium single", 500, 10, 7, 13, 10, 43},
|
||||
{"Large 5", 5000, 15, 65, 80, 60, 220},
|
||||
}
|
||||
|
||||
@@ -3087,8 +3088,8 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
|
||||
maxMsPerOpCICD float64
|
||||
}{
|
||||
{"Small", 50, 5, 2, 10, 3, 35},
|
||||
{"Medium", 500, 100, 5, 40, 20, 110},
|
||||
{"Large", 5000, 200, 60, 100, 120, 260},
|
||||
{"Medium", 500, 100, 5, 40, 20, 140},
|
||||
{"Large", 5000, 200, 60, 100, 120, 320},
|
||||
{"Small single", 50, 10, 2, 10, 5, 40},
|
||||
{"Medium single", 500, 10, 5, 40, 10, 60},
|
||||
{"Large 5", 5000, 15, 60, 100, 60, 180},
|
||||
@@ -3163,9 +3164,9 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
}{
|
||||
{"Small", 50, 5, 7, 20, 10, 80},
|
||||
{"Medium", 500, 100, 5, 40, 30, 140},
|
||||
{"Large", 5000, 200, 80, 120, 140, 300},
|
||||
{"Large", 5000, 200, 80, 120, 140, 390},
|
||||
{"Small single", 50, 10, 7, 20, 10, 80},
|
||||
{"Medium single", 500, 10, 5, 40, 20, 60},
|
||||
{"Medium single", 500, 10, 5, 40, 20, 85},
|
||||
{"Large 5", 5000, 15, 80, 120, 80, 200},
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
@@ -95,6 +96,7 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
db.SetMaxOpenConns(runtime.NumCPU())
|
||||
|
||||
crypt, err := NewFieldEncrypt(encryptionKey)
|
||||
if err != nil {
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
|
||||
account, err := initTestDNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Fatal("failed to init testing account")
|
||||
t.Fatalf("failed to init testing account: %s", err)
|
||||
}
|
||||
|
||||
dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
|
||||
@@ -125,12 +125,12 @@ func TestSaveDNSSettings(t *testing.T) {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
am, err := createDNSManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
t.Fatalf("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestDNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
t.Fatalf("failed to init testing account: %v", err)
|
||||
}
|
||||
|
||||
err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings)
|
||||
@@ -157,22 +157,22 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
||||
|
||||
am, err := createDNSManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
t.Fatalf("failed to create account manager: %s", err)
|
||||
}
|
||||
|
||||
account, err := initTestDNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
t.Fatalf("failed to init testing account: %s", err)
|
||||
}
|
||||
|
||||
peer1, err := account.FindPeerByPubKey(dnsPeer1Key)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
t.Fatalf("failed to init testing account: %s", err)
|
||||
}
|
||||
|
||||
peer2, err := account.FindPeerByPubKey(dnsPeer2Key)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
t.Fatalf("failed to init testing account: %s", err)
|
||||
}
|
||||
|
||||
newAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
|
||||
|
||||
@@ -123,7 +123,6 @@ func importCsvToSqlite(dataDir string, csvFile string, geonamesdbFile string) er
|
||||
db, err := gorm.Open(sqlite.Open(path.Join(dataDir, geonamesdbFile)), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
CreateBatchSize: 1000,
|
||||
PrepareStmt: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -132,8 +132,7 @@ func connectDB(ctx context.Context, filePath string) (*gorm.DB, error) {
|
||||
}
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(storeStr), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
PrepareStmt: true,
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -29,7 +29,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
|
||||
_, account, err := initTestGroupAccount(am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
t.Fatalf("failed to init testing account: %s", err)
|
||||
}
|
||||
for _, group := range account.Groups {
|
||||
group.Issued = types.GroupIssuedIntegration
|
||||
@@ -59,12 +59,12 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
|
||||
am, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
t.Fatalf("failed to create account manager: %s", err)
|
||||
}
|
||||
|
||||
_, account, err := initTestGroupAccount(am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
t.Fatalf("failed to init testing account: %s", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
@@ -114,6 +115,18 @@ func NewServer(
|
||||
}
|
||||
|
||||
func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
|
||||
ip := ""
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if ok {
|
||||
ip = p.Addr.String()
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("GetServerKey request from %s", ip)
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start))
|
||||
}()
|
||||
|
||||
// todo introduce something more meaningful with the key expiration/rotation
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountGetKeyRequest()
|
||||
@@ -725,6 +738,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
|
||||
// This is used for initiating an Oauth 2 device authorization grant flow
|
||||
// which will be used by our clients to Login
|
||||
func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start))
|
||||
}()
|
||||
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
|
||||
@@ -777,6 +796,12 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
|
||||
// This is used for initiating an Oauth 2 pkce authorization grant flow
|
||||
// which will be used by our clients to Login
|
||||
func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start))
|
||||
}()
|
||||
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey)
|
||||
|
||||
@@ -46,7 +46,7 @@ func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, network
|
||||
)
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
accountManager.GetAccountFromPAT,
|
||||
accountManager.GetPATInfo,
|
||||
jwtValidator.ValidateAndParse,
|
||||
accountManager.MarkPATUsed,
|
||||
accountManager.CheckUserAccessByJWTGroups,
|
||||
|
||||
@@ -32,8 +32,8 @@ func initEventsTestData(account string, events ...*activity.Event) *handler {
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) {
|
||||
return make([]*types.UserInfo, 0), nil
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) (map[string]*types.UserInfo, error) {
|
||||
return make(map[string]*types.UserInfo), nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
s "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -281,7 +282,12 @@ func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, ne
|
||||
}
|
||||
if len(router.PeerGroups) > 0 {
|
||||
for _, groupID := range router.PeerGroups {
|
||||
peerCounter += len(groups[groupID].Peers)
|
||||
group, ok := groups[groupID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Warnf("group %s not found", groupID)
|
||||
continue
|
||||
}
|
||||
peerCounter += len(group.Peers)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ var usersTestAccount = &types.Account{
|
||||
Issued: types.UserIssuedAPI,
|
||||
},
|
||||
nonDeletableServiceUserID: {
|
||||
Id: serviceUserID,
|
||||
Id: nonDeletableServiceUserID,
|
||||
Role: "admin",
|
||||
IsServiceUser: true,
|
||||
NonDeletable: true,
|
||||
@@ -70,10 +70,10 @@ func initUsersTestData() *handler {
|
||||
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
|
||||
return usersTestAccount.Users[id], nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) {
|
||||
users := make([]*types.UserInfo, 0)
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) (map[string]*types.UserInfo, error) {
|
||||
usersInfos := make(map[string]*types.UserInfo)
|
||||
for _, v := range usersTestAccount.Users {
|
||||
users = append(users, &types.UserInfo{
|
||||
usersInfos[v.Id] = &types.UserInfo{
|
||||
ID: v.Id,
|
||||
Role: string(v.Role),
|
||||
Name: "",
|
||||
@@ -81,9 +81,9 @@ func initUsersTestData() *handler {
|
||||
IsServiceUser: v.IsServiceUser,
|
||||
NonDeletable: v.NonDeletable,
|
||||
Issued: v.Issued,
|
||||
})
|
||||
}
|
||||
}
|
||||
return users, nil
|
||||
return usersInfos, nil
|
||||
},
|
||||
CreateUserFunc: func(_ context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) {
|
||||
if userID != existingUserID {
|
||||
|
||||
@@ -19,8 +19,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// GetAccountFromPATFunc function
|
||||
type GetAccountFromPATFunc func(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
|
||||
// GetAccountInfoFromPATFunc function
|
||||
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
|
||||
|
||||
// ValidateAndParseTokenFunc function
|
||||
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
||||
@@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
getAccountFromPAT GetAccountFromPATFunc
|
||||
getAccountInfoFromPAT GetAccountInfoFromPATFunc
|
||||
validateAndParseToken ValidateAndParseTokenFunc
|
||||
markPATUsed MarkPATUsedFunc
|
||||
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
|
||||
@@ -47,7 +47,7 @@ const (
|
||||
)
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
||||
audience string, userIdClaim string) *AuthMiddleware {
|
||||
if userIdClaim == "" {
|
||||
@@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
|
||||
}
|
||||
|
||||
return &AuthMiddleware{
|
||||
getAccountFromPAT: getAccountFromPAT,
|
||||
getAccountInfoFromPAT: getAccountInfoFromPAT,
|
||||
validateAndParseToken: validateAndParseToken,
|
||||
markPATUsed: markPATUsed,
|
||||
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
|
||||
@@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
||||
token, err := getTokenFromPATRequest(auth)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
|
||||
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Token: %w", err)
|
||||
}
|
||||
@@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
claimMaps := jwt.MapClaims{}
|
||||
claimMaps[m.userIDClaim] = user.Id
|
||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
|
||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
|
||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
|
||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
|
||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
|
||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
|
||||
claimMaps[jwtclaims.IsToken] = true
|
||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
||||
|
||||
@@ -34,7 +34,8 @@ var testAccount = &types.Account{
|
||||
Domain: domain,
|
||||
Users: map[string]*types.User{
|
||||
userID: {
|
||||
Id: userID,
|
||||
Id: userID,
|
||||
AccountID: accountID,
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
tokenID: {
|
||||
ID: tokenID,
|
||||
@@ -50,11 +51,11 @@ var testAccount = &types.Account{
|
||||
},
|
||||
}
|
||||
|
||||
func mockGetAccountFromPAT(_ context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
|
||||
func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
|
||||
if token == PAT {
|
||||
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
|
||||
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
|
||||
}
|
||||
return nil, nil, nil, fmt.Errorf("PAT invalid")
|
||||
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
|
||||
@@ -166,7 +167,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
)
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockGetAccountFromPAT,
|
||||
mockGetAccountInfoFromPAT,
|
||||
mockValidateAndParseToken,
|
||||
mockMarkPATUsed,
|
||||
mockCheckUserAccessByJWTGroups,
|
||||
|
||||
@@ -35,14 +35,14 @@ var benchCasesUsers = map[string]testing_tools.BenchmarkCase{
|
||||
|
||||
func BenchmarkUpdateUser(b *testing.B) {
|
||||
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||
"Users - XS": {MinMsPerOpLocal: 700, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 1300, MaxMsPerOpCICD: 8000},
|
||||
"Users - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 4, MaxMsPerOpCICD: 50},
|
||||
"Users - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 250},
|
||||
"Users - L": {MinMsPerOpLocal: 60, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 90, MaxMsPerOpCICD: 700},
|
||||
"Peers - L": {MinMsPerOpLocal: 300, MaxMsPerOpLocal: 500, MinMsPerOpCICD: 550, MaxMsPerOpCICD: 2400},
|
||||
"Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 750, MaxMsPerOpCICD: 5000},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 1000},
|
||||
"Users - XL": {MinMsPerOpLocal: 350, MaxMsPerOpLocal: 550, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 3500},
|
||||
"Users - XS": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 160, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 310},
|
||||
"Users - S": {MinMsPerOpLocal: 0.3, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 15},
|
||||
"Users - M": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 3, MaxMsPerOpCICD: 20},
|
||||
"Users - L": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 50},
|
||||
"Peers - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 310},
|
||||
"Groups - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 120},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 50},
|
||||
"Users - XL": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 280},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
@@ -118,14 +118,14 @@ func BenchmarkGetOneUser(b *testing.B) {
|
||||
|
||||
func BenchmarkGetAllUsers(b *testing.B) {
|
||||
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||
"Users - XS": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 180},
|
||||
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
|
||||
"Users - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 12, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
|
||||
"Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
|
||||
"Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
|
||||
"Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 140, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 200},
|
||||
"Users - XL": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 90},
|
||||
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10},
|
||||
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10},
|
||||
"Users - M": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 15},
|
||||
"Users - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 50},
|
||||
"Peers - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 55},
|
||||
"Groups - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55},
|
||||
"Users - XL": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 120, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
@@ -141,7 +141,7 @@ func BenchmarkGetAllUsers(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys", testing_tools.TestAdminId)
|
||||
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId)
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
@@ -152,14 +152,14 @@ func BenchmarkGetAllUsers(b *testing.B) {
|
||||
|
||||
func BenchmarkDeleteUsers(b *testing.B) {
|
||||
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||
"Users - XS": {MinMsPerOpLocal: 1000, MaxMsPerOpLocal: 1600, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 11000},
|
||||
"Users - S": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200},
|
||||
"Users - M": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 230},
|
||||
"Users - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 45, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 190},
|
||||
"Peers - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 1800},
|
||||
"Groups - L": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 1200, MaxMsPerOpCICD: 7500},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 40, MaxMsPerOpCICD: 600},
|
||||
"Users - XL": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 400},
|
||||
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Users - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
"Users - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestManagement(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Management Service Suite")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -21,9 +21,7 @@ import (
|
||||
func setupDatabase(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{
|
||||
PrepareStmt: true,
|
||||
})
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||
|
||||
require.NoError(t, err, "Failed to open database")
|
||||
return db
|
||||
|
||||
@@ -53,8 +53,8 @@ type MockAccountManager struct {
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error)
|
||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error)
|
||||
GetAccountFromPATFunc func(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||
GetPATInfoFunc func(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
|
||||
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
@@ -69,7 +69,7 @@ type MockAccountManager struct {
|
||||
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error)
|
||||
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error)
|
||||
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
|
||||
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
@@ -110,6 +110,7 @@ type MockAccountManager struct {
|
||||
GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error)
|
||||
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
|
||||
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
@@ -165,7 +166,7 @@ func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userI
|
||||
}
|
||||
|
||||
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*types.UserInfo, error) {
|
||||
func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) (map[string]*types.UserInfo, error) {
|
||||
if am.GetUsersFromAccountFunc != nil {
|
||||
return am.GetUsersFromAccountFunc(ctx, accountID, userID)
|
||||
}
|
||||
@@ -238,12 +239,12 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
|
||||
if am.GetAccountFromPATFunc != nil {
|
||||
return am.GetAccountFromPATFunc(ctx, pat)
|
||||
// GetPATInfo mock implementation of GetPATInfo from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetPATInfo(ctx context.Context, pat string) (*types.User, *types.PersonalAccessToken, string, string, error) {
|
||||
if am.GetPATInfoFunc != nil {
|
||||
return am.GetPATInfoFunc(ctx, pat)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
|
||||
return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetPATInfo is not implemented")
|
||||
}
|
||||
|
||||
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
|
||||
@@ -550,9 +551,9 @@ func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string,
|
||||
}
|
||||
|
||||
// DeleteRegularUsers mocks DeleteRegularUsers of the AccountManager interface
|
||||
func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID string, initiatorUserID string, targetUserIDs []string) error {
|
||||
func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error {
|
||||
if am.DeleteRegularUsersFunc != nil {
|
||||
return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs)
|
||||
return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs, userInfos)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteRegularUsers is not implemented")
|
||||
}
|
||||
@@ -849,3 +850,11 @@ func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peer
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented")
|
||||
}
|
||||
|
||||
// BuildUserInfosForAccount mocks BuildUserInfosForAccount of the AccountManager interface
|
||||
func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) {
|
||||
if am.BuildUserInfosForAccountFunc != nil {
|
||||
return am.BuildUserInfosForAccountFunc(ctx, accountID, initiatorUserID, accountUsers)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented")
|
||||
}
|
||||
|
||||
@@ -380,12 +380,12 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
am, err := createNSManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
t.Fatalf("failed to create account manager: %s", err)
|
||||
}
|
||||
|
||||
account, err := initTestNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
t.Fatalf("failed to init testing account: %s", err)
|
||||
}
|
||||
|
||||
outNSGroup, err := am.CreateNameServerGroup(
|
||||
@@ -608,12 +608,12 @@ func TestSaveNameServerGroup(t *testing.T) {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
am, err := createNSManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
t.Fatalf("failed to create account manager: %s", err)
|
||||
}
|
||||
|
||||
account, err := initTestNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
t.Fatalf("failed to init testing account: %s", err)
|
||||
}
|
||||
|
||||
account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup
|
||||
@@ -707,7 +707,7 @@ func TestDeleteNameServerGroup(t *testing.T) {
|
||||
|
||||
account, err := initTestNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
t.Fatalf("failed to init testing account: %s", err)
|
||||
}
|
||||
|
||||
account.NameServerGroups[testingNSGroup.ID] = testingNSGroup
|
||||
@@ -742,7 +742,7 @@ func TestGetNameServerGroup(t *testing.T) {
|
||||
|
||||
account, err := initTestNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
t.Fatalf("failed to init testing account: %s", err)
|
||||
}
|
||||
|
||||
foundGroup, err := am.GetNameServerGroup(context.Background(), account.Id, testUserID, existingNSGroupID)
|
||||
@@ -762,6 +762,7 @@ func TestGetNameServerGroup(t *testing.T) {
|
||||
|
||||
func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
t.Helper()
|
||||
|
||||
store, err := createNSStore(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -29,7 +30,6 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
@@ -1577,7 +1577,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
// Adding peer to group linked with policy should update account peers and send peer update
|
||||
t.Run("adding peer to group linked with policy", func(t *testing.T) {
|
||||
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
|
||||
Enabled: true,
|
||||
AccountID: account.Id,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
|
||||
@@ -13,13 +13,12 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
@@ -93,7 +93,7 @@ func NewPeerNotPartOfAccountError() error {
|
||||
|
||||
// NewUserNotFoundError creates a new Error with NotFound type for a missing user
|
||||
func NewUserNotFoundError(userKey string) error {
|
||||
return Errorf(NotFound, "user not found: %s", userKey)
|
||||
return Errorf(NotFound, "user: %s not found", userKey)
|
||||
}
|
||||
|
||||
// NewPeerNotRegisteredError creates a new Error with NotFound type for a missing peer
|
||||
@@ -191,3 +191,18 @@ func NewResourceNotPartOfNetworkError(resourceID, networkID string) error {
|
||||
func NewRouterNotPartOfNetworkError(routerID, networkID string) error {
|
||||
return Errorf(BadRequest, "router %s is not part of the network %s", routerID, networkID)
|
||||
}
|
||||
|
||||
// NewServiceUserRoleInvalidError creates a new Error with InvalidArgument type for creating a service user with owner role
|
||||
func NewServiceUserRoleInvalidError() error {
|
||||
return Errorf(InvalidArgument, "can't create a service user with owner role")
|
||||
}
|
||||
|
||||
// NewOwnerDeletePermissionError creates a new Error with PermissionDenied type for attempting
|
||||
// to delete a user with the owner role.
|
||||
func NewOwnerDeletePermissionError() error {
|
||||
return Errorf(PermissionDenied, "can't delete a user with the owner role")
|
||||
}
|
||||
|
||||
func NewPATNotFoundError(patID string) error {
|
||||
return Errorf(NotFound, "PAT: %s not found", patID)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
@@ -414,24 +415,16 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr
|
||||
}
|
||||
|
||||
// SaveUsers saves the given list of users to the database.
|
||||
// It updates existing users if a conflict occurs.
|
||||
func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) error {
|
||||
usersToSave := make([]types.User, 0, len(users))
|
||||
for _, user := range users {
|
||||
user.AccountID = accountID
|
||||
for id, pat := range user.PATs {
|
||||
pat.ID = id
|
||||
user.PATsG = append(user.PATsG, *pat)
|
||||
}
|
||||
usersToSave = append(usersToSave, *user)
|
||||
}
|
||||
err := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||
Create(&usersToSave).Error
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to save users to store: %v", err)
|
||||
func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error {
|
||||
if len(users) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}, clause.OnConflict{UpdateAll: true}).Create(&users)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save users to store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save users to store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -439,7 +432,8 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) err
|
||||
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
|
||||
log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save user to store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -450,7 +444,7 @@ func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength,
|
||||
return nil
|
||||
}
|
||||
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}, clause.OnConflict{UpdateAll: true}).Create(&groups)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
|
||||
}
|
||||
@@ -526,30 +520,17 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
|
||||
return token.ID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error) {
|
||||
var token types.PersonalAccessToken
|
||||
result := s.db.First(&token, idQueryCondition, tokenID)
|
||||
func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) {
|
||||
var user types.User
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
|
||||
Where("personal_access_tokens.id = ?", patID).First(&user)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
return nil, status.NewPATNotFoundError(patID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
if token.UserID == "" {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
|
||||
var user types.User
|
||||
result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID)
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
|
||||
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATsG))
|
||||
for _, pat := range user.PATsG {
|
||||
user.PATs[pat.ID] = pat.Copy()
|
||||
log.WithContext(ctx).Errorf("failed to get token user from the store: %s", result.Error)
|
||||
return nil, status.NewGetUserFromStoreError()
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
@@ -557,8 +538,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types
|
||||
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
|
||||
var user types.User
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Preload(clause.Associations).First(&user, idQueryCondition, userID)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewUserNotFoundError(userID)
|
||||
@@ -569,6 +549,25 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error {
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&types.PersonalAccessToken{}, "user_id = ?", userID)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&types.User{}, accountAndIDQueryCondition, accountID, userID).Error
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete user from the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete user from store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
|
||||
var users []*types.User
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
|
||||
@@ -899,6 +898,20 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
|
||||
return accountSettings.Settings, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
|
||||
var createdBy string
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
|
||||
Select("created_by").First(&createdBy, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
return createdBy, nil
|
||||
}
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||
var user types.User
|
||||
@@ -956,7 +969,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
|
||||
}
|
||||
|
||||
file := filepath.Join(dataDir, storeStr)
|
||||
db, err := gorm.Open(sqlite.Open(file), getGormConfig(SqliteStoreEngine))
|
||||
db, err := gorm.Open(sqlite.Open(file), getGormConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -966,7 +979,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
|
||||
|
||||
// NewPostgresqlStore creates a new Postgres store.
|
||||
func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
db, err := gorm.Open(postgres.Open(dsn), getGormConfig(PostgresStoreEngine))
|
||||
db, err := gorm.Open(postgres.Open(dsn), getGormConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -976,7 +989,7 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
|
||||
|
||||
// NewMysqlStore creates a new MySQL store.
|
||||
func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), getGormConfig(MysqlStoreEngine))
|
||||
db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), getGormConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -984,15 +997,10 @@ func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics
|
||||
return NewSqlStore(ctx, db, MysqlStoreEngine, metrics)
|
||||
}
|
||||
|
||||
func getGormConfig(engine Engine) *gorm.Config {
|
||||
prepStmt := true
|
||||
if engine == SqliteStoreEngine {
|
||||
prepStmt = false
|
||||
}
|
||||
func getGormConfig() *gorm.Config {
|
||||
return &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
CreateBatchSize: 400,
|
||||
PrepareStmt: prepStmt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2061,3 +2069,94 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
|
||||
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) {
|
||||
var pat types.PersonalAccessToken
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewPATNotFoundError(hashedToken)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get pat by hash from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get pat by hash from store")
|
||||
}
|
||||
|
||||
return &pat, nil
|
||||
}
|
||||
|
||||
// GetPATByID retrieves a personal access token by its ID and user ID.
|
||||
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) {
|
||||
var pat types.PersonalAccessToken
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&pat, "id = ? AND user_id = ?", patID, userID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewPATNotFoundError(patID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get pat from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get pat from store")
|
||||
}
|
||||
|
||||
return &pat, nil
|
||||
}
|
||||
|
||||
// GetUserPATs retrieves personal access tokens for a user.
|
||||
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) {
|
||||
var pats []*types.PersonalAccessToken
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&pats, "user_id = ?", userID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get user pat's from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get user pat's from store")
|
||||
}
|
||||
|
||||
return pats, nil
|
||||
}
|
||||
|
||||
// MarkPATUsed marks a personal access token as used.
|
||||
func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error {
|
||||
patCopy := types.PersonalAccessToken{
|
||||
LastUsed: util.ToPtr(time.Now().UTC()),
|
||||
}
|
||||
|
||||
fieldsToUpdate := []string{"last_used"}
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Select(fieldsToUpdate).
|
||||
Where(idQueryCondition, patID).Updates(&patCopy)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to mark pat as used: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to mark pat as used")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewPATNotFoundError(patID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SavePAT saves a personal access token to the database.
|
||||
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *types.PersonalAccessToken) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save pat to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePAT deletes a personal access token from the database.
|
||||
func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, userID, patID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&types.PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete pat from the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete pat from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewPATNotFoundError(patID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -37,40 +37,44 @@ import (
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func TestSqlite_NewStore(t *testing.T) {
|
||||
func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) {
|
||||
t.Helper()
|
||||
for _, engine := range supportedEngines {
|
||||
if os.Getenv("NETBIRD_STORE_ENGINE") != "" && os.Getenv("NETBIRD_STORE_ENGINE") != string(engine) {
|
||||
continue
|
||||
}
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(engine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), testDataFile, t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
t.Run(string(engine), func(t *testing.T) {
|
||||
f(t, store)
|
||||
})
|
||||
os.Unsetenv("NETBIRD_STORE_ENGINE")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NewStore(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
if len(store.GetAllAccounts(context.Background())) != 0 {
|
||||
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
||||
}
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
if store == nil {
|
||||
t.Errorf("expected to create a new Store")
|
||||
}
|
||||
if len(store.GetAllAccounts(context.Background())) != 0 {
|
||||
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlite_SaveAccount_Large(t *testing.T) {
|
||||
func Test_SaveAccount_Large(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
t.Run("SQLite", func(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
runLargeTest(t, store)
|
||||
})
|
||||
|
||||
// create store outside to have a better time counter for the test
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
t.Run("PostgreSQL", func(t *testing.T) {
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
runLargeTest(t, store)
|
||||
})
|
||||
}
|
||||
@@ -215,77 +219,74 @@ func randomIPv4() net.IP {
|
||||
return net.IP(b)
|
||||
}
|
||||
|
||||
func TestSqlite_SaveAccount(t *testing.T) {
|
||||
func Test_SaveAccount(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
setupKey, _ := types.GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
setupKey, _ := types.GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
|
||||
setupKey, _ = types.GenerateDefaultSetupKey()
|
||||
account2.SetupKeys[setupKey.Key] = setupKey
|
||||
account2.Peers["testpeer2"] = &nbpeer.Peer{
|
||||
Key: "peerkey2",
|
||||
IP: net.IP{127, 0, 0, 2},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name 2",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
|
||||
setupKey, _ = types.GenerateDefaultSetupKey()
|
||||
account2.SetupKeys[setupKey.Key] = setupKey
|
||||
account2.Peers["testpeer2"] = &nbpeer.Peer{
|
||||
Key: "peerkey2",
|
||||
IP: net.IP{127, 0, 0, 2},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name 2",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
err = store.SaveAccount(context.Background(), account2)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account2)
|
||||
require.NoError(t, err)
|
||||
if len(store.GetAllAccounts(context.Background())) != 2 {
|
||||
t.Errorf("expecting 2 Accounts to be stored after SaveAccount()")
|
||||
}
|
||||
|
||||
if len(store.GetAllAccounts(context.Background())) != 2 {
|
||||
t.Errorf("expecting 2 Accounts to be stored after SaveAccount()")
|
||||
}
|
||||
a, err := store.GetAccount(context.Background(), account.Id)
|
||||
if a == nil {
|
||||
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
|
||||
}
|
||||
|
||||
a, err := store.GetAccount(context.Background(), account.Id)
|
||||
if a == nil {
|
||||
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
|
||||
}
|
||||
if a != nil && len(a.Policies) != 1 {
|
||||
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
|
||||
}
|
||||
|
||||
if a != nil && len(a.Policies) != 1 {
|
||||
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
|
||||
}
|
||||
if a != nil && len(a.Policies[0].Rules) != 1 {
|
||||
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
|
||||
return
|
||||
}
|
||||
|
||||
if a != nil && len(a.Policies[0].Rules) != 1 {
|
||||
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
|
||||
return
|
||||
}
|
||||
if a, err := store.GetAccountByPeerPubKey(context.Background(), "peerkey"); a == nil {
|
||||
t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err)
|
||||
}
|
||||
|
||||
if a, err := store.GetAccountByPeerPubKey(context.Background(), "peerkey"); a == nil {
|
||||
t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err)
|
||||
}
|
||||
if a, err := store.GetAccountByUser(context.Background(), "testuser"); a == nil {
|
||||
t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err)
|
||||
}
|
||||
|
||||
if a, err := store.GetAccountByUser(context.Background(), "testuser"); a == nil {
|
||||
t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err)
|
||||
}
|
||||
if a, err := store.GetAccountByPeerID(context.Background(), "testpeer"); a == nil {
|
||||
t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err)
|
||||
}
|
||||
|
||||
if a, err := store.GetAccountByPeerID(context.Background(), "testpeer"); a == nil {
|
||||
t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err)
|
||||
}
|
||||
|
||||
if a, err := store.GetAccountBySetupKey(context.Background(), setupKey.Key); a == nil {
|
||||
t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err)
|
||||
}
|
||||
if a, err := store.GetAccountBySetupKey(context.Background(), setupKey.Key); a == nil {
|
||||
t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
@@ -402,27 +403,24 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlite_GetAccount(t *testing.T) {
|
||||
func Test_GetAccount(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
runTestForAllEngines(t, "../testdata/store.sql", func(t *testing.T, store Store) {
|
||||
id := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
id := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
account, err := store.GetAccount(context.Background(), id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, account.Id, "account id should match")
|
||||
|
||||
account, err := store.GetAccount(context.Background(), id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, account.Id, "account id should match")
|
||||
|
||||
_, err = store.GetAccount(context.Background(), "non-existing-account")
|
||||
assert.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||
_, err = store.GetAccount(context.Background(), "non-existing-account")
|
||||
assert.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_SavePeer(t *testing.T) {
|
||||
@@ -580,74 +578,45 @@ func TestSqlStore_SavePeerLocation(t *testing.T) {
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||
}
|
||||
|
||||
func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
func Test_TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
runTestForAllEngines(t, "../testdata/store.sql", func(t *testing.T, store Store) {
|
||||
existingDomain := "test.com"
|
||||
|
||||
existingDomain := "test.com"
|
||||
account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
|
||||
require.NoError(t, err, "should found account")
|
||||
require.Equal(t, existingDomain, account.Domain, "domains should match")
|
||||
|
||||
account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
|
||||
require.NoError(t, err, "should found account")
|
||||
require.Equal(t, existingDomain, account.Domain, "domains should match")
|
||||
|
||||
_, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com")
|
||||
require.Error(t, err, "should return error on domain lookup")
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||
_, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com")
|
||||
require.Error(t, err, "should return error on domain lookup")
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
|
||||
func Test_GetTokenIDByHashedToken(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
runTestForAllEngines(t, "../testdata/store.sql", func(t *testing.T, store Store) {
|
||||
hashed := "SoMeHaShEdToKeN"
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
hashed := "SoMeHaShEdToKeN"
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, token)
|
||||
|
||||
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, token)
|
||||
|
||||
_, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash")
|
||||
require.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||
}
|
||||
|
||||
func TestSqlite_GetUserByTokenID(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
user, err := store.GetUserByTokenID(context.Background(), id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, user.PATs[id].ID)
|
||||
|
||||
_, err = store.GetUserByTokenID(context.Background(), "non-existing-id")
|
||||
require.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||
_, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash")
|
||||
require.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
@@ -962,23 +931,6 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
|
||||
require.Equal(t, id, token)
|
||||
}
|
||||
|
||||
func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
user, err := store.GetUserByTokenID(context.Background(), id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, user.PATs[id].ID)
|
||||
}
|
||||
|
||||
func TestSqlite_GetTakenIPs(t *testing.T) {
|
||||
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
@@ -1182,7 +1134,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSqlite_GetAccoundUsers(t *testing.T) {
|
||||
func TestSqlStore_GetAccountUsers(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
if err != nil {
|
||||
@@ -1371,6 +1323,14 @@ func TestSqlStore_SaveGroups(t *testing.T) {
|
||||
}
|
||||
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups[1].Peers = []string{}
|
||||
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups)
|
||||
require.NoError(t, err)
|
||||
|
||||
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groups[1].ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, groups[1], group)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteGroup(t *testing.T) {
|
||||
@@ -2935,3 +2895,392 @@ func TestSqlStore_DatabaseBlocking(t *testing.T) {
|
||||
|
||||
t.Logf("Test completed")
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountCreatedBy(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
expectError bool
|
||||
createdBy string
|
||||
}{
|
||||
{
|
||||
name: "existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectError: false,
|
||||
createdBy: "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||
},
|
||||
{
|
||||
name: "non-existing account ID",
|
||||
accountID: "nonexistent",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty account ID",
|
||||
accountID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
createdBy, err := store.GetAccountCreatedBy(context.Background(), LockingStrengthShare, tt.accountID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Empty(t, createdBy)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, createdBy)
|
||||
require.Equal(t, tt.createdBy, createdBy)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSqlStore_GetUserByUserID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing user",
|
||||
userID: "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing user",
|
||||
userID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty user ID",
|
||||
userID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, tt.userID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, user)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
require.Equal(t, tt.userID, user.Id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetUserByPATID(t *testing.T) {
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id)
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveUser(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
user := &types.User{
|
||||
Id: "user-id",
|
||||
AccountID: accountID,
|
||||
Role: types.UserRoleAdmin,
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"groupA", "groupB"},
|
||||
Blocked: false,
|
||||
LastLogin: util.ToPtr(time.Now().UTC()),
|
||||
CreatedAt: time.Now().UTC().Add(-time.Hour),
|
||||
Issued: types.UserIssuedIntegration,
|
||||
}
|
||||
err = store.SaveUser(context.Background(), LockingStrengthUpdate, user)
|
||||
require.NoError(t, err)
|
||||
|
||||
saveUser, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, user.Id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, user.Id, saveUser.Id)
|
||||
require.Equal(t, user.AccountID, saveUser.AccountID)
|
||||
require.Equal(t, user.Role, saveUser.Role)
|
||||
require.Equal(t, user.AutoGroups, saveUser.AutoGroups)
|
||||
require.WithinDurationf(t, user.GetLastLogin(), saveUser.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal")
|
||||
require.WithinDurationf(t, user.CreatedAt, saveUser.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
|
||||
require.Equal(t, user.Issued, saveUser.Issued)
|
||||
require.Equal(t, user.Blocked, saveUser.Blocked)
|
||||
require.Equal(t, user.IsServiceUser, saveUser.IsServiceUser)
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveUsers(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
accountUsers, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, accountUsers, 2)
|
||||
|
||||
users := []*types.User{
|
||||
{
|
||||
Id: "user-1",
|
||||
AccountID: accountID,
|
||||
Issued: "api",
|
||||
AutoGroups: []string{"groupA", "groupB"},
|
||||
},
|
||||
{
|
||||
Id: "user-2",
|
||||
AccountID: accountID,
|
||||
Issued: "integration",
|
||||
AutoGroups: []string{"groupA"},
|
||||
},
|
||||
}
|
||||
err = store.SaveUsers(context.Background(), LockingStrengthUpdate, users)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountUsers, err = store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, accountUsers, 4)
|
||||
|
||||
users[1].AutoGroups = []string{"groupA", "groupC"}
|
||||
err = store.SaveUsers(context.Background(), LockingStrengthUpdate, users)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, users[1].Id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, users[1].AutoGroups, user.AutoGroups)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteUser(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
err = store.DeleteUser(context.Background(), LockingStrengthUpdate, accountID, userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, userID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, user)
|
||||
|
||||
userPATs, err := store.GetUserPATs(context.Background(), LockingStrengthShare, userID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, userPATs, 0)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPATByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
patID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing PAT",
|
||||
patID: "9dj38s35-63fb-11ec-90d6-0242ac120003",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing PAT",
|
||||
patID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty PAT ID",
|
||||
patID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, tt.patID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, pat)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pat)
|
||||
require.Equal(t, tt.patID, pat.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetUserPATs(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
userPATs, err := store.GetUserPATs(context.Background(), LockingStrengthShare, "f4f6d672-63fb-11ec-90d6-0242ac120003")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, userPATs, 1)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPATByHashedToken(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, "SoMeHaShEdToKeN")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "9dj38s35-63fb-11ec-90d6-0242ac120003", pat.ID)
|
||||
}
|
||||
|
||||
func TestSqlStore_MarkPATUsed(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
|
||||
patID := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
err = store.MarkPATUsed(context.Background(), LockingStrengthUpdate, patID)
|
||||
require.NoError(t, err)
|
||||
|
||||
pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, patID)
|
||||
require.NoError(t, err)
|
||||
now := time.Now().UTC()
|
||||
require.WithinRange(t, pat.LastUsed.UTC(), now.Add(-15*time.Second), now, "LastUsed should be within 1 second of now")
|
||||
}
|
||||
|
||||
func TestSqlStore_SavePAT(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "edafee4e-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
pat := &types.PersonalAccessToken{
|
||||
ID: "pat-id",
|
||||
UserID: userID,
|
||||
Name: "token",
|
||||
HashedToken: "SoMeHaShEdToKeN",
|
||||
ExpirationDate: util.ToPtr(time.Now().UTC().Add(12 * time.Hour)),
|
||||
CreatedBy: userID,
|
||||
CreatedAt: time.Now().UTC().Add(time.Hour),
|
||||
LastUsed: util.ToPtr(time.Now().UTC().Add(-15 * time.Minute)),
|
||||
}
|
||||
err = store.SavePAT(context.Background(), LockingStrengthUpdate, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
savePAT, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, pat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pat.ID, savePAT.ID)
|
||||
require.Equal(t, pat.UserID, savePAT.UserID)
|
||||
require.Equal(t, pat.HashedToken, savePAT.HashedToken)
|
||||
require.Equal(t, pat.CreatedBy, savePAT.CreatedBy)
|
||||
require.WithinDurationf(t, pat.GetExpirationDate(), savePAT.ExpirationDate.UTC(), time.Millisecond, "ExpirationDate should be equal")
|
||||
require.WithinDurationf(t, pat.CreatedAt, savePAT.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
|
||||
require.WithinDurationf(t, pat.GetLastUsed(), savePAT.LastUsed.UTC(), time.Millisecond, "LastUsed should be equal")
|
||||
}
|
||||
|
||||
func TestSqlStore_DeletePAT(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
|
||||
patID := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
err = store.DeletePAT(context.Background(), LockingStrengthUpdate, userID, patID)
|
||||
require.NoError(t, err)
|
||||
|
||||
pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, patID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, pat)
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveUsers_LargeBatch(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
accountUsers, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, accountUsers, 2)
|
||||
|
||||
usersToSave := make([]*types.User, 0)
|
||||
|
||||
for i := 1; i <= 8000; i++ {
|
||||
usersToSave = append(usersToSave, &types.User{
|
||||
Id: fmt.Sprintf("user-%d", i),
|
||||
AccountID: accountID,
|
||||
Role: types.UserRoleUser,
|
||||
})
|
||||
}
|
||||
|
||||
err = store.SaveUsers(context.Background(), LockingStrengthUpdate, usersToSave)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountUsers, err = store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 8002, len(accountUsers))
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
accountGroups, err := store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, accountGroups, 3)
|
||||
|
||||
groupsToSave := make([]*types.Group, 0)
|
||||
|
||||
for i := 1; i <= 8000; i++ {
|
||||
groupsToSave = append(groupsToSave, &types.Group{
|
||||
ID: fmt.Sprintf("%d", i),
|
||||
AccountID: accountID,
|
||||
Name: fmt.Sprintf("group-%d", i),
|
||||
})
|
||||
}
|
||||
|
||||
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groupsToSave)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountGroups, err = store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 8003, len(accountGroups))
|
||||
}
|
||||
|
||||
@@ -9,11 +9,16 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -59,21 +64,30 @@ type Store interface {
|
||||
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error)
|
||||
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error)
|
||||
GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error)
|
||||
SaveAccount(ctx context.Context, account *types.Account) error
|
||||
DeleteAccount(ctx context.Context, account *types.Account) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
|
||||
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error)
|
||||
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
|
||||
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error)
|
||||
SaveUsers(accountID string, users map[string]*types.User) error
|
||||
SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error
|
||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error
|
||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
|
||||
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
|
||||
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
|
||||
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error)
|
||||
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
|
||||
SavePAT(ctx context.Context, strength LockingStrength, pat *types.PersonalAccessToken) error
|
||||
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
|
||||
|
||||
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
|
||||
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
|
||||
@@ -184,6 +198,8 @@ const (
|
||||
mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
|
||||
)
|
||||
|
||||
var supportedEngines = []Engine{SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine}
|
||||
|
||||
func getStoreEngineFromEnv() Engine {
|
||||
// NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file.
|
||||
kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE")
|
||||
@@ -192,7 +208,7 @@ func getStoreEngineFromEnv() Engine {
|
||||
}
|
||||
|
||||
value := Engine(strings.ToLower(kind))
|
||||
if value == SqliteStoreEngine || value == PostgresStoreEngine || value == MysqlStoreEngine {
|
||||
if slices.Contains(supportedEngines, value) {
|
||||
return value
|
||||
}
|
||||
|
||||
@@ -319,7 +335,7 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (
|
||||
}
|
||||
|
||||
file := filepath.Join(dataDir, storeStr)
|
||||
db, err := gorm.Open(sqlite.Open(file), getGormConfig(kind))
|
||||
db, err := gorm.Open(sqlite.Open(file), getGormConfig())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -340,51 +356,126 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (
|
||||
}
|
||||
|
||||
func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) {
|
||||
if kind == PostgresStoreEngine {
|
||||
cleanUp, err := testutil.CreatePostgresTestContainer()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
var cleanup func()
|
||||
var err error
|
||||
switch kind {
|
||||
case PostgresStoreEngine:
|
||||
store, cleanup, err = newReusedPostgresStore(ctx, store, kind)
|
||||
case MysqlStoreEngine:
|
||||
store, cleanup, err = newReusedMysqlStore(ctx, store, kind)
|
||||
default:
|
||||
cleanup = func() {
|
||||
// sqlite doesn't need to be cleaned up
|
||||
}
|
||||
|
||||
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
|
||||
}
|
||||
|
||||
store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return store, cleanUp, nil
|
||||
}
|
||||
|
||||
if kind == MysqlStoreEngine {
|
||||
cleanUp, err := testutil.CreateMysqlTestContainer()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dsn, ok := os.LookupEnv(mysqlDsnEnv)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
|
||||
}
|
||||
|
||||
store, err = NewMysqlStoreFromSqlStore(ctx, store, dsn, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return store, cleanUp, nil
|
||||
if err != nil {
|
||||
return nil, cleanup, fmt.Errorf("failed to create test store: %v", err)
|
||||
}
|
||||
|
||||
closeConnection := func() {
|
||||
cleanup()
|
||||
store.Close(ctx)
|
||||
}
|
||||
|
||||
return store, closeConnection, nil
|
||||
}
|
||||
|
||||
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) {
|
||||
if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" {
|
||||
var err error
|
||||
_, err = testutil.CreatePostgresTestContainer()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
|
||||
}
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err)
|
||||
}
|
||||
|
||||
dsn, cleanup, err := createRandomDB(dsn, db, kind)
|
||||
if err != nil {
|
||||
return nil, cleanup, err
|
||||
}
|
||||
|
||||
store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil)
|
||||
if err != nil {
|
||||
return nil, cleanup, err
|
||||
}
|
||||
|
||||
return store, cleanup, nil
|
||||
}
|
||||
|
||||
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) {
|
||||
if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" {
|
||||
var err error
|
||||
_, err = testutil.CreateMysqlTestContainer()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
dsn, ok := os.LookupEnv(mysqlDsnEnv)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
|
||||
}
|
||||
|
||||
db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to open mysql connection: %v", err)
|
||||
}
|
||||
|
||||
dsn, cleanup, err := createRandomDB(dsn, db, kind)
|
||||
if err != nil {
|
||||
return nil, cleanup, err
|
||||
}
|
||||
|
||||
store, err = NewMysqlStoreFromSqlStore(ctx, store, dsn, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return store, cleanup, nil
|
||||
}
|
||||
|
||||
func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), error) {
|
||||
dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_"))
|
||||
|
||||
if err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)).Error; err != nil {
|
||||
return "", nil, fmt.Errorf("failed to create database: %v", err)
|
||||
}
|
||||
|
||||
var err error
|
||||
cleanup := func() {
|
||||
switch engine {
|
||||
case PostgresStoreEngine:
|
||||
err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error
|
||||
case MysqlStoreEngine:
|
||||
// err = killMySQLConnections(dsn, dbName)
|
||||
err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error
|
||||
}
|
||||
if err != nil {
|
||||
log.Errorf("failed to drop database %s: %v", dbName, err)
|
||||
panic(err)
|
||||
}
|
||||
sqlDB, _ := db.DB()
|
||||
_ = sqlDB.Close()
|
||||
}
|
||||
|
||||
return replaceDBName(dsn, dbName), cleanup, nil
|
||||
}
|
||||
|
||||
func replaceDBName(dsn, newDBName string) string {
|
||||
re := regexp.MustCompile(`(?P<pre>[:/@])(?P<dbname>[^/?]+)(?P<post>\?|$)`)
|
||||
return re.ReplaceAllString(dsn, `${pre}`+newDBName+`${post}`)
|
||||
}
|
||||
|
||||
func loadSQL(db *gorm.DB, filepath string) error {
|
||||
sqlContent, err := os.ReadFile(filepath)
|
||||
if err != nil {
|
||||
|
||||
2
management/server/testdata/store.sql
vendored
2
management/server/testdata/store.sql
vendored
@@ -37,7 +37,7 @@ CREATE INDEX `idx_network_resources_id` ON `network_resources`(`id`);
|
||||
CREATE INDEX `idx_networks_id` ON `networks`(`id`);
|
||||
CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`);
|
||||
|
||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
||||
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0);
|
||||
INSERT INTO users VALUES('a23efe53-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','owner',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,'');
|
||||
|
||||
@@ -22,7 +22,7 @@ func CreateMysqlTestContainer() (func(), error) {
|
||||
myContainer, err := mysql.RunContainer(ctx,
|
||||
testcontainers.WithImage("mlsmaycon/warmed-mysql:8"),
|
||||
mysql.WithDatabase("testing"),
|
||||
mysql.WithUsername("testing"),
|
||||
mysql.WithUsername("root"),
|
||||
mysql.WithPassword("testing"),
|
||||
testcontainers.WithWaitStrategy(
|
||||
wait.ForLog("/usr/sbin/mysqld: ready for connections").
|
||||
@@ -34,6 +34,7 @@ func CreateMysqlTestContainer() (func(), error) {
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
os.Unsetenv("NETBIRD_STORE_ENGINE_MYSQL_DSN")
|
||||
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second)
|
||||
defer cancelFunc()
|
||||
if err = myContainer.Terminate(timeoutCtx); err != nil {
|
||||
@@ -68,6 +69,7 @@ func CreatePostgresTestContainer() (func(), error) {
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
os.Unsetenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN")
|
||||
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second)
|
||||
defer cancelFunc()
|
||||
if err = pgContainer.Terminate(timeoutCtx); err != nil {
|
||||
|
||||
@@ -75,7 +75,7 @@ type PersonalAccessTokenGenerated struct {
|
||||
|
||||
// CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User.
|
||||
// Additionally, it will return the token in plain text once, to give to the user and only save a hashed version
|
||||
func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) {
|
||||
func CreateNewPAT(name string, expirationInDays int, targetID, createdBy string) (*PersonalAccessTokenGenerated, error) {
|
||||
hashedToken, plainToken, err := generateNewToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -84,6 +84,7 @@ func CreateNewPAT(name string, expirationInDays int, createdBy string) (*Persona
|
||||
return &PersonalAccessTokenGenerated{
|
||||
PersonalAccessToken: PersonalAccessToken{
|
||||
ID: xid.New().String(),
|
||||
UserID: targetID,
|
||||
Name: name,
|
||||
HashedToken: hashedToken,
|
||||
ExpirationDate: util.ToPtr(currentTime.AddDate(0, 0, expirationInDays)),
|
||||
|
||||
@@ -80,7 +80,7 @@ type User struct {
|
||||
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
||||
AutoGroups []string `gorm:"serializer:json"`
|
||||
PATs map[string]*PersonalAccessToken `gorm:"-"`
|
||||
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"`
|
||||
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
||||
Blocked bool
|
||||
// LastLogin is the last time the user logged in to IdP
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,6 +11,7 @@ import (
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -45,7 +46,7 @@ const (
|
||||
)
|
||||
|
||||
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Error when creating store: %s", err)
|
||||
}
|
||||
@@ -53,13 +54,13 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
err = s.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
Store: s,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
@@ -81,7 +82,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
|
||||
assert.Equal(t, pat.ID, tokenID)
|
||||
|
||||
user, err := am.Store.GetUserByTokenID(context.Background(), tokenID)
|
||||
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting user by token ID: %s", err)
|
||||
}
|
||||
@@ -855,7 +856,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
{
|
||||
name: "Delete non-existent user",
|
||||
userIDs: []string{"non-existent-user"},
|
||||
expectedReasons: []string{"target user: non-existent-user not found"},
|
||||
expectedReasons: []string{"user: non-existent-user not found"},
|
||||
expectedNotDeleted: []string{},
|
||||
},
|
||||
{
|
||||
@@ -867,7 +868,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs)
|
||||
userInfos, err := am.BuildUserInfosForAccount(context.Background(), mockAccountID, mockUserID, maps.Values(account.Users))
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs, userInfos)
|
||||
if len(tc.expectedReasons) > 0 {
|
||||
assert.Error(t, err)
|
||||
var foundExpectedErrors int
|
||||
|
||||
Reference in New Issue
Block a user