Merge branch 'main' into feature/port-forwarding

This commit is contained in:
Viktor Liu
2025-02-20 11:31:04 +01:00
191 changed files with 10566 additions and 3093 deletions

View File

@@ -68,7 +68,7 @@ type AccountManager interface {
SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) (*types.UserInfo, error)
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error)
@@ -80,7 +80,7 @@ type AccountManager interface {
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
GetPATInfo(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
DeleteAccount(ctx context.Context, accountID, userID string) error
MarkPATUsed(ctx context.Context, tokenID string) error
GetUserByID(ctx context.Context, id string) (*types.User, error)
@@ -97,7 +97,7 @@ type AccountManager interface {
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error)
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error)
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
@@ -150,6 +150,7 @@ type AccountManager interface {
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
UpdateAccountPeers(ctx context.Context, accountID string)
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
}
type DefaultAccountManager struct {
@@ -622,6 +623,12 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
if user.Role != types.UserRoleOwner {
return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account")
}
userInfosMap, err := am.BuildUserInfosForAccount(ctx, accountID, userID, maps.Values(account.Users))
if err != nil {
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
}
for _, otherUser := range account.Users {
if otherUser.IsServiceUser {
continue
@@ -631,13 +638,23 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
continue
}
deleteUserErr := am.deleteRegularUser(ctx, account, userID, otherUser.Id)
userInfo, ok := userInfosMap[otherUser.Id]
if !ok {
return status.Errorf(status.NotFound, "user info not found for user %s", otherUser.Id)
}
_, deleteUserErr := am.deleteRegularUser(ctx, accountID, userID, userInfo)
if deleteUserErr != nil {
return deleteUserErr
}
}
err = am.deleteRegularUser(ctx, account, userID, userID)
userInfo, ok := userInfosMap[userID]
if !ok {
return status.Errorf(status.NotFound, "user info not found for user %s", userID)
}
_, err = am.deleteRegularUser(ctx, accountID, userID, userInfo)
if err != nil {
log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err)
return err
@@ -694,20 +711,8 @@ func isNil(i idp.Manager) bool {
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return err
}
cachedAccount := &types.Account{
Id: accountID,
Users: make(map[string]*types.User),
}
for _, user := range accountUsers {
cachedAccount.Users[user.Id] = user
}
// user can be nil if it wasn't found (e.g., just created)
user, err := am.lookupUserInCache(ctx, userID, cachedAccount)
user, err := am.lookupUserInCache(ctx, userID, accountID)
if err != nil {
return err
}
@@ -783,10 +788,15 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
}
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *types.Account) (*idp.UserData, error) {
users := make(map[string]userLoggedInOnce, len(account.Users))
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
users := make(map[string]userLoggedInOnce, len(accountUsers))
// ignore service users and users provisioned by integrations than are never logged in
for _, user := range account.Users {
for _, user := range accountUsers {
if user.IsServiceUser {
continue
}
@@ -795,8 +805,8 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
}
users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero())
}
log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, account.Id)
userData, err := am.lookupCache(ctx, users, account.Id)
log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, accountID)
userData, err := am.lookupCache(ctx, users, accountID)
if err != nil {
return nil, err
}
@@ -809,13 +819,13 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
user, err := account.FindUser(userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, account.Id)
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
return nil, err
}
key := user.IntegrationReference.CacheKey(account.Id, userID)
key := user.IntegrationReference.CacheKey(accountID, userID)
ud, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil {
log.WithContext(ctx).Debugf("failed to get externalCache for key: %s, error: %s", key, err)
@@ -1055,9 +1065,9 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount()
usersMap := make(map[string]*types.User)
usersMap[claims.UserId] = types.NewRegularUser(claims.UserId)
err := am.Store.SaveUsers(domainAccountID, usersMap)
newUser := types.NewRegularUser(claims.UserId)
newUser.AccountID = domainAccountID
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser)
if err != nil {
return "", err
}
@@ -1080,12 +1090,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
return nil
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := am.lookupUserInCache(ctx, userID, account)
user, err := am.lookupUserInCache(ctx, userID, accountID)
if err != nil {
return err
}
@@ -1095,17 +1100,17 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
}
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, account.Id)
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID)
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
// Our job is to just reload cache.
go func() {
_, err = am.refreshCache(ctx, account.Id)
_, err = am.refreshCache(ctx, accountID)
if err != nil {
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id)
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, accountID)
return
}
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, account.Id)
am.StoreEvent(ctx, userID, userID, account.Id, activity.UserJoined, nil)
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, accountID)
am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil)
}()
}
@@ -1114,33 +1119,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
// MarkPATUsed marks a personal access token as used
func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error {
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
if err != nil {
return err
}
account, err := am.Store.GetAccountByUser(ctx, user.Id)
if err != nil {
return err
}
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
account, err = am.Store.GetAccountByUser(ctx, user.Id)
if err != nil {
return err
}
pat, ok := account.Users[user.Id].PATs[tokenID]
if !ok {
return fmt.Errorf("token not found")
}
pat.LastUsed = util.ToPtr(time.Now().UTC())
return am.Store.SaveAccount(ctx, account)
return am.Store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID)
}
// GetAccount returns an account associated with this account ID.
@@ -1148,52 +1127,64 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
return am.Store.GetAccount(ctx, accountID)
}
// GetAccountFromPAT returns Account and User associated with a personal access token
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
func (am *DefaultAccountManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
user, pat, err = am.extractPATFromToken(ctx, token)
if err != nil {
return nil, nil, "", "", err
}
domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
if err != nil {
return nil, nil, "", "", err
}
return user, pat, domain, category, nil
}
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) {
if len(token) != types.PATLength {
return nil, nil, nil, fmt.Errorf("token has wrong length")
return nil, nil, fmt.Errorf("token has incorrect length")
}
prefix := token[:len(types.PATPrefix)]
if prefix != types.PATPrefix {
return nil, nil, nil, fmt.Errorf("token has wrong prefix")
return nil, nil, fmt.Errorf("token has wrong prefix")
}
secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
verificationChecksum, err := base62.Decode(encodedChecksum)
if err != nil {
return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
}
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
if secretChecksum != verificationChecksum {
return nil, nil, nil, fmt.Errorf("token checksum does not match")
return nil, nil, fmt.Errorf("token checksum does not match")
}
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
tokenID, err := am.Store.GetTokenIDByHashedToken(ctx, encodedHashedToken)
var user *types.User
var pat *types.PersonalAccessToken
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
if err != nil {
return err
}
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
return err
})
if err != nil {
return nil, nil, nil, err
return nil, nil, err
}
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
if err != nil {
return nil, nil, nil, err
}
account, err := am.Store.GetAccountByUser(ctx, user.Id)
if err != nil {
return nil, nil, nil, err
}
pat := user.PATs[tokenID]
if pat == nil {
return nil, nil, nil, fmt.Errorf("personal access token not found")
}
return account, user, pat, nil
return user, pat, nil
}
// GetAccountByID returns an account associated with this account ID.
@@ -1339,7 +1330,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error getting user peers: %w", err)
}
updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
updatedGroups, err := updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
if err != nil {
return fmt.Errorf("error modifying user peers in groups: %w", err)
}

View File

@@ -733,6 +733,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
PATs: map[string]*types.PersonalAccessToken{
"tokenId": {
ID: "tokenId",
UserID: "someUser",
HashedToken: encodedHashedToken,
},
},
@@ -746,14 +747,14 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
Store: store,
}
account, user, pat, err := am.GetAccountFromPAT(context.Background(), token)
user, pat, _, _, err := am.GetPATInfo(context.Background(), token)
if err != nil {
t.Fatalf("Error when getting Account from PAT: %s", err)
}
assert.Equal(t, "account_id", account.Id)
assert.Equal(t, "account_id", user.AccountID)
assert.Equal(t, "someUser", user.Id)
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat)
assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID)
}
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
@@ -3018,11 +3019,11 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 1, 5, 3, 19},
{"Medium", 500, 100, 7, 22, 10, 90},
{"Large", 5000, 200, 65, 110, 60, 240},
{"Small", 50, 5, 1, 5, 3, 24},
{"Medium", 500, 100, 7, 22, 10, 135},
{"Large", 5000, 200, 65, 110, 60, 320},
{"Small single", 50, 10, 1, 4, 3, 80},
{"Medium single", 500, 10, 7, 13, 10, 37},
{"Medium single", 500, 10, 7, 13, 10, 43},
{"Large 5", 5000, 15, 65, 80, 60, 220},
}
@@ -3087,8 +3088,8 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 2, 10, 3, 35},
{"Medium", 500, 100, 5, 40, 20, 110},
{"Large", 5000, 200, 60, 100, 120, 260},
{"Medium", 500, 100, 5, 40, 20, 140},
{"Large", 5000, 200, 60, 100, 120, 320},
{"Small single", 50, 10, 2, 10, 5, 40},
{"Medium single", 500, 10, 5, 40, 10, 60},
{"Large 5", 5000, 15, 60, 100, 60, 180},
@@ -3163,9 +3164,9 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
}{
{"Small", 50, 5, 7, 20, 10, 80},
{"Medium", 500, 100, 5, 40, 30, 140},
{"Large", 5000, 200, 80, 120, 140, 300},
{"Large", 5000, 200, 80, 120, 140, 390},
{"Small single", 50, 10, 7, 20, 10, 80},
{"Medium single", 500, 10, 5, 40, 20, 60},
{"Medium single", 500, 10, 5, 40, 20, 85},
{"Large 5", 5000, 15, 80, 120, 80, 200},
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"path/filepath"
"runtime"
"time"
_ "github.com/mattn/go-sqlite3"
@@ -95,6 +96,7 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
if err != nil {
return nil, err
}
db.SetMaxOpenConns(runtime.NumCPU())
crypt, err := NewFieldEncrypt(encryptionKey)
if err != nil {

View File

@@ -43,7 +43,7 @@ func TestGetDNSSettings(t *testing.T) {
account, err := initTestDNSAccount(t, am)
if err != nil {
t.Fatal("failed to init testing account")
t.Fatalf("failed to init testing account: %s", err)
}
dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
@@ -125,12 +125,12 @@ func TestSaveDNSSettings(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
am, err := createDNSManager(t)
if err != nil {
t.Error("failed to create account manager")
t.Fatalf("failed to create account manager")
}
account, err := initTestDNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
t.Fatalf("failed to init testing account: %v", err)
}
err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings)
@@ -157,22 +157,22 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
am, err := createDNSManager(t)
if err != nil {
t.Error("failed to create account manager")
t.Fatalf("failed to create account manager: %s", err)
}
account, err := initTestDNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
t.Fatalf("failed to init testing account: %s", err)
}
peer1, err := account.FindPeerByPubKey(dnsPeer1Key)
if err != nil {
t.Error("failed to init testing account")
t.Fatalf("failed to init testing account: %s", err)
}
peer2, err := account.FindPeerByPubKey(dnsPeer2Key)
if err != nil {
t.Error("failed to init testing account")
t.Fatalf("failed to init testing account: %s", err)
}
newAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)

View File

@@ -123,7 +123,6 @@ func importCsvToSqlite(dataDir string, csvFile string, geonamesdbFile string) er
db, err := gorm.Open(sqlite.Open(path.Join(dataDir, geonamesdbFile)), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
CreateBatchSize: 1000,
PrepareStmt: true,
})
if err != nil {
return err

View File

@@ -132,8 +132,7 @@ func connectDB(ctx context.Context, filePath string) (*gorm.DB, error) {
}
db, err := gorm.Open(sqlite.Open(storeStr), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
PrepareStmt: true,
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
return nil, err

View File

@@ -29,7 +29,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
_, account, err := initTestGroupAccount(am)
if err != nil {
t.Error("failed to init testing account")
t.Fatalf("failed to init testing account: %s", err)
}
for _, group := range account.Groups {
group.Issued = types.GroupIssuedIntegration
@@ -59,12 +59,12 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
am, err := createManager(t)
if err != nil {
t.Error("failed to create account manager")
t.Fatalf("failed to create account manager: %s", err)
}
_, account, err := initTestGroupAccount(am)
if err != nil {
t.Error("failed to init testing account")
t.Fatalf("failed to init testing account: %s", err)
}
testCases := []struct {

View File

@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/encryption"
@@ -114,6 +115,18 @@ func NewServer(
}
func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
ip := ""
p, ok := peer.FromContext(ctx)
if ok {
ip = p.Addr.String()
}
log.WithContext(ctx).Tracef("GetServerKey request from %s", ip)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetServerKey from %s took %v", ip, time.Since(start))
}()
// todo introduce something more meaningful with the key expiration/rotation
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountGetKeyRequest()
@@ -725,6 +738,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
// This is used for initiating an Oauth 2 device authorization grant flow
// which will be used by our clients to Login
func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow for pubKey: %s took %v", req.WgPubKey, time.Since(start))
}()
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
@@ -777,6 +796,12 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
// This is used for initiating an Oauth 2 pkce authorization grant flow
// which will be used by our clients to Login
func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now()
defer func() {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow for pubKey %s took %v", req.WgPubKey, time.Since(start))
}()
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey)

View File

@@ -46,7 +46,7 @@ func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, network
)
authMiddleware := middleware.NewAuthMiddleware(
accountManager.GetAccountFromPAT,
accountManager.GetPATInfo,
jwtValidator.ValidateAndParse,
accountManager.MarkPATUsed,
accountManager.CheckUserAccessByJWTGroups,

View File

@@ -32,8 +32,8 @@ func initEventsTestData(account string, events ...*activity.Event) *handler {
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) {
return make([]*types.UserInfo, 0), nil
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) (map[string]*types.UserInfo, error) {
return make(map[string]*types.UserInfo), nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(

View File

@@ -7,6 +7,7 @@ import (
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/groups"
@@ -281,7 +282,12 @@ func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, ne
}
if len(router.PeerGroups) > 0 {
for _, groupID := range router.PeerGroups {
peerCounter += len(groups[groupID].Peers)
group, ok := groups[groupID]
if !ok {
log.WithContext(ctx).Warnf("group %s not found", groupID)
continue
}
peerCounter += len(group.Peers)
}
}
}

View File

@@ -52,7 +52,7 @@ var usersTestAccount = &types.Account{
Issued: types.UserIssuedAPI,
},
nonDeletableServiceUserID: {
Id: serviceUserID,
Id: nonDeletableServiceUserID,
Role: "admin",
IsServiceUser: true,
NonDeletable: true,
@@ -70,10 +70,10 @@ func initUsersTestData() *handler {
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
return usersTestAccount.Users[id], nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) {
users := make([]*types.UserInfo, 0)
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) (map[string]*types.UserInfo, error) {
usersInfos := make(map[string]*types.UserInfo)
for _, v := range usersTestAccount.Users {
users = append(users, &types.UserInfo{
usersInfos[v.Id] = &types.UserInfo{
ID: v.Id,
Role: string(v.Role),
Name: "",
@@ -81,9 +81,9 @@ func initUsersTestData() *handler {
IsServiceUser: v.IsServiceUser,
NonDeletable: v.NonDeletable,
Issued: v.Issued,
})
}
}
return users, nil
return usersInfos, nil
},
CreateUserFunc: func(_ context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) {
if userID != existingUserID {

View File

@@ -19,8 +19,8 @@ import (
"github.com/netbirdio/netbird/management/server/types"
)
// GetAccountFromPATFunc function
type GetAccountFromPATFunc func(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
// GetAccountInfoFromPATFunc function
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error)
// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
@@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc
getAccountInfoFromPAT GetAccountInfoFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
@@ -47,7 +47,7 @@ const (
)
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
audience string, userIdClaim string) *AuthMiddleware {
if userIdClaim == "" {
@@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
}
return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT,
getAccountInfoFromPAT: getAccountInfoFromPAT,
validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed,
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
@@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
token, err := getTokenFromPATRequest(auth)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
return fmt.Errorf("error extracting token: %w", err)
}
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
}
@@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
claimMaps := jwt.MapClaims{}
claimMaps[m.userIDClaim] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
claimMaps[jwtclaims.IsToken] = true
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint

View File

@@ -34,7 +34,8 @@ var testAccount = &types.Account{
Domain: domain,
Users: map[string]*types.User{
userID: {
Id: userID,
Id: userID,
AccountID: accountID,
PATs: map[string]*types.PersonalAccessToken{
tokenID: {
ID: tokenID,
@@ -50,11 +51,11 @@ var testAccount = &types.Account{
},
}
func mockGetAccountFromPAT(_ context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
if token == PAT {
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
}
return nil, nil, nil, fmt.Errorf("PAT invalid")
return nil, nil, "", "", fmt.Errorf("PAT invalid")
}
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
@@ -166,7 +167,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
)
authMiddleware := NewAuthMiddleware(
mockGetAccountFromPAT,
mockGetAccountInfoFromPAT,
mockValidateAndParseToken,
mockMarkPATUsed,
mockCheckUserAccessByJWTGroups,

View File

@@ -35,14 +35,14 @@ var benchCasesUsers = map[string]testing_tools.BenchmarkCase{
func BenchmarkUpdateUser(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Users - XS": {MinMsPerOpLocal: 700, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 1300, MaxMsPerOpCICD: 8000},
"Users - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 4, MaxMsPerOpCICD: 50},
"Users - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 250},
"Users - L": {MinMsPerOpLocal: 60, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 90, MaxMsPerOpCICD: 700},
"Peers - L": {MinMsPerOpLocal: 300, MaxMsPerOpLocal: 500, MinMsPerOpCICD: 550, MaxMsPerOpCICD: 2400},
"Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 750, MaxMsPerOpCICD: 5000},
"Setup Keys - L": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 1000},
"Users - XL": {MinMsPerOpLocal: 350, MaxMsPerOpLocal: 550, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 3500},
"Users - XS": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 160, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 310},
"Users - S": {MinMsPerOpLocal: 0.3, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 15},
"Users - M": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 3, MaxMsPerOpCICD: 20},
"Users - L": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 50},
"Peers - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 310},
"Groups - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 120},
"Setup Keys - L": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 50},
"Users - XL": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 280},
}
log.SetOutput(io.Discard)
@@ -118,14 +118,14 @@ func BenchmarkGetOneUser(b *testing.B) {
func BenchmarkGetAllUsers(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Users - XS": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 180},
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
"Users - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 12, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
"Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
"Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
"Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30},
"Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 140, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 200},
"Users - XL": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 90},
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10},
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 10},
"Users - M": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 15},
"Users - L": {MinMsPerOpLocal: 10, MaxMsPerOpLocal: 20, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 50},
"Peers - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 55},
"Groups - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55},
"Setup Keys - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 25, MinMsPerOpCICD: 25, MaxMsPerOpCICD: 55},
"Users - XL": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 120, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300},
}
log.SetOutput(io.Discard)
@@ -141,7 +141,7 @@ func BenchmarkGetAllUsers(b *testing.B) {
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys", testing_tools.TestAdminId)
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req)
}
@@ -152,14 +152,14 @@ func BenchmarkGetAllUsers(b *testing.B) {
func BenchmarkDeleteUsers(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Users - XS": {MinMsPerOpLocal: 1000, MaxMsPerOpLocal: 1600, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 11000},
"Users - S": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200},
"Users - M": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 230},
"Users - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 45, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 190},
"Peers - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 1800},
"Groups - L": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 1200, MaxMsPerOpCICD: 7500},
"Setup Keys - L": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 40, MaxMsPerOpCICD: 600},
"Users - XL": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 400},
"Users - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Users - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
"Users - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15},
}
log.SetOutput(io.Discard)

View File

@@ -1,13 +0,0 @@
package server_test
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestManagement(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Management Service Suite")
}

File diff suppressed because it is too large Load Diff

View File

@@ -21,9 +21,7 @@ import (
func setupDatabase(t *testing.T) *gorm.DB {
t.Helper()
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{
PrepareStmt: true,
})
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
require.NoError(t, err, "Failed to open database")
return db

View File

@@ -53,8 +53,8 @@ type MockAccountManager struct {
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error)
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error)
GetAccountFromPATFunc func(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
GetPATInfoFunc func(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
MarkPATUsedFunc func(ctx context.Context, pat string) error
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
@@ -69,7 +69,7 @@ type MockAccountManager struct {
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error)
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error)
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
@@ -110,6 +110,7 @@ type MockAccountManager struct {
GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error)
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error)
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
}
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
@@ -165,7 +166,7 @@ func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userI
}
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*types.UserInfo, error) {
func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) (map[string]*types.UserInfo, error) {
if am.GetUsersFromAccountFunc != nil {
return am.GetUsersFromAccountFunc(ctx, accountID, userID)
}
@@ -238,12 +239,12 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
if am.GetAccountFromPATFunc != nil {
return am.GetAccountFromPATFunc(ctx, pat)
// GetPATInfo mock implementation of GetPATInfo from server.AccountManager interface
func (am *MockAccountManager) GetPATInfo(ctx context.Context, pat string) (*types.User, *types.PersonalAccessToken, string, string, error) {
if am.GetPATInfoFunc != nil {
return am.GetPATInfoFunc(ctx, pat)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetPATInfo is not implemented")
}
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
@@ -550,9 +551,9 @@ func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string,
}
// DeleteRegularUsers mocks DeleteRegularUsers of the AccountManager interface
func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID string, initiatorUserID string, targetUserIDs []string) error {
func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error {
if am.DeleteRegularUsersFunc != nil {
return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs)
return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs, userInfos)
}
return status.Errorf(codes.Unimplemented, "method DeleteRegularUsers is not implemented")
}
@@ -849,3 +850,11 @@ func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peer
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented")
}
// BuildUserInfosForAccount mocks BuildUserInfosForAccount of the AccountManager interface
func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) {
if am.BuildUserInfosForAccountFunc != nil {
return am.BuildUserInfosForAccountFunc(ctx, accountID, initiatorUserID, accountUsers)
}
return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented")
}

View File

@@ -380,12 +380,12 @@ func TestCreateNameServerGroup(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
am, err := createNSManager(t)
if err != nil {
t.Error("failed to create account manager")
t.Fatalf("failed to create account manager: %s", err)
}
account, err := initTestNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
t.Fatalf("failed to init testing account: %s", err)
}
outNSGroup, err := am.CreateNameServerGroup(
@@ -608,12 +608,12 @@ func TestSaveNameServerGroup(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
am, err := createNSManager(t)
if err != nil {
t.Error("failed to create account manager")
t.Fatalf("failed to create account manager: %s", err)
}
account, err := initTestNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
t.Fatalf("failed to init testing account: %s", err)
}
account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup
@@ -707,7 +707,7 @@ func TestDeleteNameServerGroup(t *testing.T) {
account, err := initTestNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
t.Fatalf("failed to init testing account: %s", err)
}
account.NameServerGroups[testingNSGroup.ID] = testingNSGroup
@@ -742,7 +742,7 @@ func TestGetNameServerGroup(t *testing.T) {
account, err := initTestNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
t.Fatalf("failed to init testing account: %s", err)
}
foundGroup, err := am.GetNameServerGroup(context.Background(), account.Id, testUserID, existingNSGroupID)
@@ -762,6 +762,7 @@ func TestGetNameServerGroup(t *testing.T) {
func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
t.Helper()
store, err := createNSStore(t)
if err != nil {
return nil, err

View File

@@ -13,6 +13,7 @@ import (
"testing"
"time"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@@ -29,7 +30,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
@@ -1577,7 +1577,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
// Adding peer to group linked with policy should update account peers and send peer update
t.Run("adding peer to group linked with policy", func(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
Enabled: true,
AccountID: account.Id,
Enabled: true,
Rules: []*types.PolicyRule{
{
Enabled: true,

View File

@@ -13,13 +13,12 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"

View File

@@ -93,7 +93,7 @@ func NewPeerNotPartOfAccountError() error {
// NewUserNotFoundError creates a new Error with NotFound type for a missing user
func NewUserNotFoundError(userKey string) error {
return Errorf(NotFound, "user not found: %s", userKey)
return Errorf(NotFound, "user: %s not found", userKey)
}
// NewPeerNotRegisteredError creates a new Error with NotFound type for a missing peer
@@ -191,3 +191,18 @@ func NewResourceNotPartOfNetworkError(resourceID, networkID string) error {
func NewRouterNotPartOfNetworkError(routerID, networkID string) error {
return Errorf(BadRequest, "router %s is not part of the network %s", routerID, networkID)
}
// NewServiceUserRoleInvalidError creates a new Error with InvalidArgument type for creating a service user with owner role
func NewServiceUserRoleInvalidError() error {
return Errorf(InvalidArgument, "can't create a service user with owner role")
}
// NewOwnerDeletePermissionError creates a new Error with PermissionDenied type for attempting
// to delete a user with the owner role.
func NewOwnerDeletePermissionError() error {
return Errorf(PermissionDenied, "can't delete a user with the owner role")
}
func NewPATNotFoundError(patID string) error {
return Errorf(NotFound, "PAT: %s not found", patID)
}

View File

@@ -15,6 +15,7 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/management/server/util"
log "github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
@@ -414,24 +415,16 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr
}
// SaveUsers saves the given list of users to the database.
// It updates existing users if a conflict occurs.
func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) error {
usersToSave := make([]types.User, 0, len(users))
for _, user := range users {
user.AccountID = accountID
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
usersToSave = append(usersToSave, *user)
}
err := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).
Create(&usersToSave).Error
if err != nil {
return status.Errorf(status.Internal, "failed to save users to store: %v", err)
func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error {
if len(users) == 0 {
return nil
}
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}, clause.OnConflict{UpdateAll: true}).Create(&users)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save users to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save users to store")
}
return nil
}
@@ -439,7 +432,8 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) err
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save user to store")
}
return nil
}
@@ -450,7 +444,7 @@ func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength,
return nil
}
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}, clause.OnConflict{UpdateAll: true}).Create(&groups)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
}
@@ -526,30 +520,17 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
return token.ID, nil
}
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error) {
var token types.PersonalAccessToken
result := s.db.First(&token, idQueryCondition, tokenID)
func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) {
var user types.User
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
Where("personal_access_tokens.id = ?", patID).First(&user)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
return nil, status.NewPATNotFoundError(patID)
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return nil, status.NewGetAccountFromStoreError(result.Error)
}
if token.UserID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
var user types.User
result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATsG))
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
log.WithContext(ctx).Errorf("failed to get token user from the store: %s", result.Error)
return nil, status.NewGetUserFromStoreError()
}
return &user, nil
@@ -557,8 +538,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
var user types.User
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Preload(clause.Associations).First(&user, idQueryCondition, userID)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
@@ -569,6 +549,25 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return &user, nil
}
func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error {
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&types.PersonalAccessToken{}, "user_id = ?", userID)
if result.Error != nil {
return result.Error
}
return tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&types.User{}, accountAndIDQueryCondition, accountID, userID).Error
})
if err != nil {
log.WithContext(ctx).Errorf("failed to delete user from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete user from store")
}
return nil
}
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
var users []*types.User
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
@@ -899,6 +898,20 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
return accountSettings.Settings, nil
}
func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
var createdBy string
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
Select("created_by").First(&createdBy, idQueryCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.NewAccountNotFoundError(accountID)
}
return "", status.NewGetAccountFromStoreError(result.Error)
}
return createdBy, nil
}
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user types.User
@@ -956,7 +969,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
}
file := filepath.Join(dataDir, storeStr)
db, err := gorm.Open(sqlite.Open(file), getGormConfig(SqliteStoreEngine))
db, err := gorm.Open(sqlite.Open(file), getGormConfig())
if err != nil {
return nil, err
}
@@ -966,7 +979,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
// NewPostgresqlStore creates a new Postgres store.
func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
db, err := gorm.Open(postgres.Open(dsn), getGormConfig(PostgresStoreEngine))
db, err := gorm.Open(postgres.Open(dsn), getGormConfig())
if err != nil {
return nil, err
}
@@ -976,7 +989,7 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
// NewMysqlStore creates a new MySQL store.
func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), getGormConfig(MysqlStoreEngine))
db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), getGormConfig())
if err != nil {
return nil, err
}
@@ -984,15 +997,10 @@ func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics
return NewSqlStore(ctx, db, MysqlStoreEngine, metrics)
}
func getGormConfig(engine Engine) *gorm.Config {
prepStmt := true
if engine == SqliteStoreEngine {
prepStmt = false
}
func getGormConfig() *gorm.Config {
return &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
CreateBatchSize: 400,
PrepareStmt: prepStmt,
}
}
@@ -2061,3 +2069,94 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki
return nil
}
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) {
var pat types.PersonalAccessToken
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewPATNotFoundError(hashedToken)
}
log.WithContext(ctx).Errorf("failed to get pat by hash from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get pat by hash from store")
}
return &pat, nil
}
// GetPATByID retrieves a personal access token by its ID and user ID.
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) {
var pat types.PersonalAccessToken
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&pat, "id = ? AND user_id = ?", patID, userID)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewPATNotFoundError(patID)
}
log.WithContext(ctx).Errorf("failed to get pat from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get pat from store")
}
return &pat, nil
}
// GetUserPATs retrieves personal access tokens for a user.
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) {
var pats []*types.PersonalAccessToken
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&pats, "user_id = ?", userID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get user pat's from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get user pat's from store")
}
return pats, nil
}
// MarkPATUsed marks a personal access token as used.
func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error {
patCopy := types.PersonalAccessToken{
LastUsed: util.ToPtr(time.Now().UTC()),
}
fieldsToUpdate := []string{"last_used"}
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Select(fieldsToUpdate).
Where(idQueryCondition, patID).Updates(&patCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to mark pat as used: %s", result.Error)
return status.Errorf(status.Internal, "failed to mark pat as used")
}
if result.RowsAffected == 0 {
return status.NewPATNotFoundError(patID)
}
return nil
}
// SavePAT saves a personal access token to the database.
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *types.PersonalAccessToken) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err)
return status.Errorf(status.Internal, "failed to save pat to store")
}
return nil
}
// DeletePAT deletes a personal access token from the database.
func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, userID, patID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&types.PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete pat from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete pat from store")
}
if result.RowsAffected == 0 {
return status.NewPATNotFoundError(patID)
}
return nil
}

View File

@@ -37,40 +37,44 @@ import (
nbroute "github.com/netbirdio/netbird/route"
)
func TestSqlite_NewStore(t *testing.T) {
func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) {
t.Helper()
for _, engine := range supportedEngines {
if os.Getenv("NETBIRD_STORE_ENGINE") != "" && os.Getenv("NETBIRD_STORE_ENGINE") != string(engine) {
continue
}
t.Setenv("NETBIRD_STORE_ENGINE", string(engine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), testDataFile, t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
t.Run(string(engine), func(t *testing.T) {
f(t, store)
})
os.Unsetenv("NETBIRD_STORE_ENGINE")
}
}
func Test_NewStore(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 0 {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
}
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
if store == nil {
t.Errorf("expected to create a new Store")
}
if len(store.GetAllAccounts(context.Background())) != 0 {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
}
})
}
func TestSqlite_SaveAccount_Large(t *testing.T) {
func Test_SaveAccount_Large(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
t.Run("SQLite", func(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
runLargeTest(t, store)
})
// create store outside to have a better time counter for the test
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
t.Run("PostgreSQL", func(t *testing.T) {
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
runLargeTest(t, store)
})
}
@@ -215,77 +219,74 @@ func randomIPv4() net.IP {
return net.IP(b)
}
func TestSqlite_SaveAccount(t *testing.T) {
func Test_SaveAccount(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
setupKey, _ := types.GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
setupKey, _ := types.GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
err := store.SaveAccount(context.Background(), account)
require.NoError(t, err)
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
setupKey, _ = types.GenerateDefaultSetupKey()
account2.SetupKeys[setupKey.Key] = setupKey
account2.Peers["testpeer2"] = &nbpeer.Peer{
Key: "peerkey2",
IP: net.IP{127, 0, 0, 2},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
setupKey, _ = types.GenerateDefaultSetupKey()
account2.SetupKeys[setupKey.Key] = setupKey
account2.Peers["testpeer2"] = &nbpeer.Peer{
Key: "peerkey2",
IP: net.IP{127, 0, 0, 2},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
err = store.SaveAccount(context.Background(), account2)
require.NoError(t, err)
err = store.SaveAccount(context.Background(), account2)
require.NoError(t, err)
if len(store.GetAllAccounts(context.Background())) != 2 {
t.Errorf("expecting 2 Accounts to be stored after SaveAccount()")
}
if len(store.GetAllAccounts(context.Background())) != 2 {
t.Errorf("expecting 2 Accounts to be stored after SaveAccount()")
}
a, err := store.GetAccount(context.Background(), account.Id)
if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
}
a, err := store.GetAccount(context.Background(), account.Id)
if a == nil {
t.Errorf("expecting Account to be stored after SaveAccount(): %v", err)
}
if a != nil && len(a.Policies) != 1 {
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
}
if a != nil && len(a.Policies) != 1 {
t.Errorf("expecting Account to have one policy stored after SaveAccount(), got %d", len(a.Policies))
}
if a != nil && len(a.Policies[0].Rules) != 1 {
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
return
}
if a != nil && len(a.Policies[0].Rules) != 1 {
t.Errorf("expecting Account to have one policy rule stored after SaveAccount(), got %d", len(a.Policies[0].Rules))
return
}
if a, err := store.GetAccountByPeerPubKey(context.Background(), "peerkey"); a == nil {
t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err)
}
if a, err := store.GetAccountByPeerPubKey(context.Background(), "peerkey"); a == nil {
t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err)
}
if a, err := store.GetAccountByUser(context.Background(), "testuser"); a == nil {
t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err)
}
if a, err := store.GetAccountByUser(context.Background(), "testuser"); a == nil {
t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err)
}
if a, err := store.GetAccountByPeerID(context.Background(), "testpeer"); a == nil {
t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err)
}
if a, err := store.GetAccountByPeerID(context.Background(), "testpeer"); a == nil {
t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err)
}
if a, err := store.GetAccountBySetupKey(context.Background(), setupKey.Key); a == nil {
t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err)
}
if a, err := store.GetAccountBySetupKey(context.Background(), setupKey.Key); a == nil {
t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err)
}
})
}
func TestSqlite_DeleteAccount(t *testing.T) {
@@ -402,27 +403,24 @@ func TestSqlite_DeleteAccount(t *testing.T) {
}
}
func TestSqlite_GetAccount(t *testing.T) {
func Test_GetAccount(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
runTestForAllEngines(t, "../testdata/store.sql", func(t *testing.T, store Store) {
id := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
id := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
account, err := store.GetAccount(context.Background(), id)
require.NoError(t, err)
require.Equal(t, id, account.Id, "account id should match")
account, err := store.GetAccount(context.Background(), id)
require.NoError(t, err)
require.Equal(t, id, account.Id, "account id should match")
_, err = store.GetAccount(context.Background(), "non-existing-account")
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
_, err = store.GetAccount(context.Background(), "non-existing-account")
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
})
}
func TestSqlStore_SavePeer(t *testing.T) {
@@ -580,74 +578,45 @@ func TestSqlStore_SavePeerLocation(t *testing.T) {
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}
func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) {
func Test_TestGetAccountByPrivateDomain(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
runTestForAllEngines(t, "../testdata/store.sql", func(t *testing.T, store Store) {
existingDomain := "test.com"
existingDomain := "test.com"
account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
require.NoError(t, err, "should found account")
require.Equal(t, existingDomain, account.Domain, "domains should match")
account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
require.NoError(t, err, "should found account")
require.Equal(t, existingDomain, account.Domain, "domains should match")
_, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com")
require.Error(t, err, "should return error on domain lookup")
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
_, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com")
require.Error(t, err, "should return error on domain lookup")
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
})
}
func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
func Test_GetTokenIDByHashedToken(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
runTestForAllEngines(t, "../testdata/store.sql", func(t *testing.T, store Store) {
hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed)
require.NoError(t, err)
require.Equal(t, id, token)
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed)
require.NoError(t, err)
require.Equal(t, id, token)
_, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash")
require.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}
func TestSqlite_GetUserByTokenID(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(context.Background(), id)
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)
_, err = store.GetUserByTokenID(context.Background(), "non-existing-id")
require.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
_, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash")
require.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
})
}
func TestMigrate(t *testing.T) {
@@ -962,23 +931,6 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
require.Equal(t, id, token)
}
func TestPostgresql_GetUserByTokenID(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(context.Background(), id)
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)
}
func TestSqlite_GetTakenIPs(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
@@ -1182,7 +1134,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
assert.NoError(t, err)
}
func TestSqlite_GetAccoundUsers(t *testing.T) {
func TestSqlStore_GetAccountUsers(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
@@ -1371,6 +1323,14 @@ func TestSqlStore_SaveGroups(t *testing.T) {
}
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups)
require.NoError(t, err)
groups[1].Peers = []string{}
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups)
require.NoError(t, err)
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groups[1].ID)
require.NoError(t, err)
require.Equal(t, groups[1], group)
}
func TestSqlStore_DeleteGroup(t *testing.T) {
@@ -2935,3 +2895,392 @@ func TestSqlStore_DatabaseBlocking(t *testing.T) {
t.Logf("Test completed")
}
func TestSqlStore_GetAccountCreatedBy(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
tests := []struct {
name string
accountID string
expectError bool
createdBy string
}{
{
name: "existing account ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectError: false,
createdBy: "edafee4e-63fb-11ec-90d6-0242ac120003",
},
{
name: "non-existing account ID",
accountID: "nonexistent",
expectError: true,
},
{
name: "empty account ID",
accountID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
createdBy, err := store.GetAccountCreatedBy(context.Background(), LockingStrengthShare, tt.accountID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Empty(t, createdBy)
} else {
require.NoError(t, err)
require.NotNil(t, createdBy)
require.Equal(t, tt.createdBy, createdBy)
}
})
}
}
func TestSqlStore_GetUserByUserID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
tests := []struct {
name string
userID string
expectError bool
}{
{
name: "retrieve existing user",
userID: "edafee4e-63fb-11ec-90d6-0242ac120003",
expectError: false,
},
{
name: "retrieve non-existing user",
userID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty user ID",
userID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, tt.userID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, user)
} else {
require.NoError(t, err)
require.NotNil(t, user)
require.Equal(t, tt.userID, user.Id)
}
})
}
}
func TestSqlStore_GetUserByPATID(t *testing.T) {
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
require.NoError(t, err)
require.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id)
}
func TestSqlStore_SaveUser(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
user := &types.User{
Id: "user-id",
AccountID: accountID,
Role: types.UserRoleAdmin,
IsServiceUser: false,
AutoGroups: []string{"groupA", "groupB"},
Blocked: false,
LastLogin: util.ToPtr(time.Now().UTC()),
CreatedAt: time.Now().UTC().Add(-time.Hour),
Issued: types.UserIssuedIntegration,
}
err = store.SaveUser(context.Background(), LockingStrengthUpdate, user)
require.NoError(t, err)
saveUser, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, user.Id)
require.NoError(t, err)
require.Equal(t, user.Id, saveUser.Id)
require.Equal(t, user.AccountID, saveUser.AccountID)
require.Equal(t, user.Role, saveUser.Role)
require.Equal(t, user.AutoGroups, saveUser.AutoGroups)
require.WithinDurationf(t, user.GetLastLogin(), saveUser.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal")
require.WithinDurationf(t, user.CreatedAt, saveUser.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
require.Equal(t, user.Issued, saveUser.Issued)
require.Equal(t, user.Blocked, saveUser.Blocked)
require.Equal(t, user.IsServiceUser, saveUser.IsServiceUser)
}
func TestSqlStore_SaveUsers(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
accountUsers, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
require.Len(t, accountUsers, 2)
users := []*types.User{
{
Id: "user-1",
AccountID: accountID,
Issued: "api",
AutoGroups: []string{"groupA", "groupB"},
},
{
Id: "user-2",
AccountID: accountID,
Issued: "integration",
AutoGroups: []string{"groupA"},
},
}
err = store.SaveUsers(context.Background(), LockingStrengthUpdate, users)
require.NoError(t, err)
accountUsers, err = store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
require.Len(t, accountUsers, 4)
users[1].AutoGroups = []string{"groupA", "groupC"}
err = store.SaveUsers(context.Background(), LockingStrengthUpdate, users)
require.NoError(t, err)
user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, users[1].Id)
require.NoError(t, err)
require.Equal(t, users[1].AutoGroups, user.AutoGroups)
}
func TestSqlStore_DeleteUser(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
err = store.DeleteUser(context.Background(), LockingStrengthUpdate, accountID, userID)
require.NoError(t, err)
user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, userID)
require.Error(t, err)
require.Nil(t, user)
userPATs, err := store.GetUserPATs(context.Background(), LockingStrengthShare, userID)
require.NoError(t, err)
require.Len(t, userPATs, 0)
}
func TestSqlStore_GetPATByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
tests := []struct {
name string
patID string
expectError bool
}{
{
name: "retrieve existing PAT",
patID: "9dj38s35-63fb-11ec-90d6-0242ac120003",
expectError: false,
},
{
name: "retrieve non-existing PAT",
patID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty PAT ID",
patID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, tt.patID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, pat)
} else {
require.NoError(t, err)
require.NotNil(t, pat)
require.Equal(t, tt.patID, pat.ID)
}
})
}
}
func TestSqlStore_GetUserPATs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
userPATs, err := store.GetUserPATs(context.Background(), LockingStrengthShare, "f4f6d672-63fb-11ec-90d6-0242ac120003")
require.NoError(t, err)
require.Len(t, userPATs, 1)
}
func TestSqlStore_GetPATByHashedToken(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, "SoMeHaShEdToKeN")
require.NoError(t, err)
require.Equal(t, "9dj38s35-63fb-11ec-90d6-0242ac120003", pat.ID)
}
func TestSqlStore_MarkPATUsed(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
patID := "9dj38s35-63fb-11ec-90d6-0242ac120003"
err = store.MarkPATUsed(context.Background(), LockingStrengthUpdate, patID)
require.NoError(t, err)
pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, patID)
require.NoError(t, err)
now := time.Now().UTC()
require.WithinRange(t, pat.LastUsed.UTC(), now.Add(-15*time.Second), now, "LastUsed should be within 1 second of now")
}
func TestSqlStore_SavePAT(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
userID := "edafee4e-63fb-11ec-90d6-0242ac120003"
pat := &types.PersonalAccessToken{
ID: "pat-id",
UserID: userID,
Name: "token",
HashedToken: "SoMeHaShEdToKeN",
ExpirationDate: util.ToPtr(time.Now().UTC().Add(12 * time.Hour)),
CreatedBy: userID,
CreatedAt: time.Now().UTC().Add(time.Hour),
LastUsed: util.ToPtr(time.Now().UTC().Add(-15 * time.Minute)),
}
err = store.SavePAT(context.Background(), LockingStrengthUpdate, pat)
require.NoError(t, err)
savePAT, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, pat.ID)
require.NoError(t, err)
require.Equal(t, pat.ID, savePAT.ID)
require.Equal(t, pat.UserID, savePAT.UserID)
require.Equal(t, pat.HashedToken, savePAT.HashedToken)
require.Equal(t, pat.CreatedBy, savePAT.CreatedBy)
require.WithinDurationf(t, pat.GetExpirationDate(), savePAT.ExpirationDate.UTC(), time.Millisecond, "ExpirationDate should be equal")
require.WithinDurationf(t, pat.CreatedAt, savePAT.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal")
require.WithinDurationf(t, pat.GetLastUsed(), savePAT.LastUsed.UTC(), time.Millisecond, "LastUsed should be equal")
}
func TestSqlStore_DeletePAT(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
userID := "f4f6d672-63fb-11ec-90d6-0242ac120003"
patID := "9dj38s35-63fb-11ec-90d6-0242ac120003"
err = store.DeletePAT(context.Background(), LockingStrengthUpdate, userID, patID)
require.NoError(t, err)
pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, patID)
require.Error(t, err)
require.Nil(t, pat)
}
func TestSqlStore_SaveUsers_LargeBatch(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
accountUsers, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
require.Len(t, accountUsers, 2)
usersToSave := make([]*types.User, 0)
for i := 1; i <= 8000; i++ {
usersToSave = append(usersToSave, &types.User{
Id: fmt.Sprintf("user-%d", i),
AccountID: accountID,
Role: types.UserRoleUser,
})
}
err = store.SaveUsers(context.Background(), LockingStrengthUpdate, usersToSave)
require.NoError(t, err)
accountUsers, err = store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
require.Equal(t, 8002, len(accountUsers))
}
func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
accountGroups, err := store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
require.Len(t, accountGroups, 3)
groupsToSave := make([]*types.Group, 0)
for i := 1; i <= 8000; i++ {
groupsToSave = append(groupsToSave, &types.Group{
ID: fmt.Sprintf("%d", i),
AccountID: accountID,
Name: fmt.Sprintf("group-%d", i),
})
}
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groupsToSave)
require.NoError(t, err)
accountGroups, err = store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
require.Equal(t, 8003, len(accountGroups))
}

View File

@@ -9,11 +9,16 @@ import (
"os"
"path"
"path/filepath"
"regexp"
"runtime"
"slices"
"strings"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
@@ -59,21 +64,30 @@ type Store interface {
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error)
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error)
GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error)
SaveAccount(ctx context.Context, account *types.Account) error
DeleteAccount(ctx context.Context, account *types.Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error)
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error)
SaveUsers(accountID string, users map[string]*types.User) error
SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error)
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
SavePAT(ctx context.Context, strength LockingStrength, pat *types.PersonalAccessToken) error
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
@@ -184,6 +198,8 @@ const (
mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
)
var supportedEngines = []Engine{SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine}
func getStoreEngineFromEnv() Engine {
// NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file.
kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE")
@@ -192,7 +208,7 @@ func getStoreEngineFromEnv() Engine {
}
value := Engine(strings.ToLower(kind))
if value == SqliteStoreEngine || value == PostgresStoreEngine || value == MysqlStoreEngine {
if slices.Contains(supportedEngines, value) {
return value
}
@@ -319,7 +335,7 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (
}
file := filepath.Join(dataDir, storeStr)
db, err := gorm.Open(sqlite.Open(file), getGormConfig(kind))
db, err := gorm.Open(sqlite.Open(file), getGormConfig())
if err != nil {
return nil, nil, err
}
@@ -340,51 +356,126 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (
}
func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) {
if kind == PostgresStoreEngine {
cleanUp, err := testutil.CreatePostgresTestContainer()
if err != nil {
return nil, nil, err
var cleanup func()
var err error
switch kind {
case PostgresStoreEngine:
store, cleanup, err = newReusedPostgresStore(ctx, store, kind)
case MysqlStoreEngine:
store, cleanup, err = newReusedMysqlStore(ctx, store, kind)
default:
cleanup = func() {
// sqlite doesn't need to be cleaned up
}
dsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok {
return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
}
store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil)
if err != nil {
return nil, nil, err
}
return store, cleanUp, nil
}
if kind == MysqlStoreEngine {
cleanUp, err := testutil.CreateMysqlTestContainer()
if err != nil {
return nil, nil, err
}
dsn, ok := os.LookupEnv(mysqlDsnEnv)
if !ok {
return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
}
store, err = NewMysqlStoreFromSqlStore(ctx, store, dsn, nil)
if err != nil {
return nil, nil, err
}
return store, cleanUp, nil
if err != nil {
return nil, cleanup, fmt.Errorf("failed to create test store: %v", err)
}
closeConnection := func() {
cleanup()
store.Close(ctx)
}
return store, closeConnection, nil
}
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) {
if envDsn, ok := os.LookupEnv(postgresDsnEnv); !ok || envDsn == "" {
var err error
_, err = testutil.CreatePostgresTestContainer()
if err != nil {
return nil, nil, err
}
}
dsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok {
return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
}
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err)
}
dsn, cleanup, err := createRandomDB(dsn, db, kind)
if err != nil {
return nil, cleanup, err
}
store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil)
if err != nil {
return nil, cleanup, err
}
return store, cleanup, nil
}
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind Engine) (*SqlStore, func(), error) {
if envDsn, ok := os.LookupEnv(mysqlDsnEnv); !ok || envDsn == "" {
var err error
_, err = testutil.CreateMysqlTestContainer()
if err != nil {
return nil, nil, err
}
}
dsn, ok := os.LookupEnv(mysqlDsnEnv)
if !ok {
return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
}
db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
if err != nil {
return nil, nil, fmt.Errorf("failed to open mysql connection: %v", err)
}
dsn, cleanup, err := createRandomDB(dsn, db, kind)
if err != nil {
return nil, cleanup, err
}
store, err = NewMysqlStoreFromSqlStore(ctx, store, dsn, nil)
if err != nil {
return nil, nil, err
}
return store, cleanup, nil
}
func createRandomDB(dsn string, db *gorm.DB, engine Engine) (string, func(), error) {
dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_"))
if err := db.Exec(fmt.Sprintf("CREATE DATABASE %s", dbName)).Error; err != nil {
return "", nil, fmt.Errorf("failed to create database: %v", err)
}
var err error
cleanup := func() {
switch engine {
case PostgresStoreEngine:
err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error
case MysqlStoreEngine:
// err = killMySQLConnections(dsn, dbName)
err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error
}
if err != nil {
log.Errorf("failed to drop database %s: %v", dbName, err)
panic(err)
}
sqlDB, _ := db.DB()
_ = sqlDB.Close()
}
return replaceDBName(dsn, dbName), cleanup, nil
}
func replaceDBName(dsn, newDBName string) string {
re := regexp.MustCompile(`(?P<pre>[:/@])(?P<dbname>[^/?]+)(?P<post>\?|$)`)
return re.ReplaceAllString(dsn, `${pre}`+newDBName+`${post}`)
}
func loadSQL(db *gorm.DB, filepath string) error {
sqlContent, err := os.ReadFile(filepath)
if err != nil {

View File

@@ -37,7 +37,7 @@ CREATE INDEX `idx_network_resources_id` ON `network_resources`(`id`);
CREATE INDEX `idx_networks_id` ON `networks`(`id`);
CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0);
INSERT INTO users VALUES('a23efe53-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','owner',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,'');

View File

@@ -22,7 +22,7 @@ func CreateMysqlTestContainer() (func(), error) {
myContainer, err := mysql.RunContainer(ctx,
testcontainers.WithImage("mlsmaycon/warmed-mysql:8"),
mysql.WithDatabase("testing"),
mysql.WithUsername("testing"),
mysql.WithUsername("root"),
mysql.WithPassword("testing"),
testcontainers.WithWaitStrategy(
wait.ForLog("/usr/sbin/mysqld: ready for connections").
@@ -34,6 +34,7 @@ func CreateMysqlTestContainer() (func(), error) {
}
cleanup := func() {
os.Unsetenv("NETBIRD_STORE_ENGINE_MYSQL_DSN")
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second)
defer cancelFunc()
if err = myContainer.Terminate(timeoutCtx); err != nil {
@@ -68,6 +69,7 @@ func CreatePostgresTestContainer() (func(), error) {
}
cleanup := func() {
os.Unsetenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN")
timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second)
defer cancelFunc()
if err = pgContainer.Terminate(timeoutCtx); err != nil {

View File

@@ -75,7 +75,7 @@ type PersonalAccessTokenGenerated struct {
// CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User.
// Additionally, it will return the token in plain text once, to give to the user and only save a hashed version
func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) {
func CreateNewPAT(name string, expirationInDays int, targetID, createdBy string) (*PersonalAccessTokenGenerated, error) {
hashedToken, plainToken, err := generateNewToken()
if err != nil {
return nil, err
@@ -84,6 +84,7 @@ func CreateNewPAT(name string, expirationInDays int, createdBy string) (*Persona
return &PersonalAccessTokenGenerated{
PersonalAccessToken: PersonalAccessToken{
ID: xid.New().String(),
UserID: targetID,
Name: name,
HashedToken: hashedToken,
ExpirationDate: util.ToPtr(currentTime.AddDate(0, 0, expirationInDays)),

View File

@@ -80,7 +80,7 @@ type User struct {
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
AutoGroups []string `gorm:"serializer:json"`
PATs map[string]*PersonalAccessToken `gorm:"-"`
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"`
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool
// LastLogin is the last time the user logged in to IdP

File diff suppressed because it is too large Load Diff

View File

@@ -11,6 +11,7 @@ import (
cacheStore "github.com/eko/gocache/v3/store"
"github.com/google/go-cmp/cmp"
"github.com/netbirdio/netbird/management/server/util"
"golang.org/x/exp/maps"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
@@ -45,7 +46,7 @@ const (
)
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
@@ -53,13 +54,13 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err = store.SaveAccount(context.Background(), account)
err = s.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
am := DefaultAccountManager{
Store: store,
Store: s,
eventStore: &activity.InMemoryEventStore{},
}
@@ -81,7 +82,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
assert.Equal(t, pat.ID, tokenID)
user, err := am.Store.GetUserByTokenID(context.Background(), tokenID)
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID)
if err != nil {
t.Fatalf("Error when getting user by token ID: %s", err)
}
@@ -855,7 +856,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
{
name: "Delete non-existent user",
userIDs: []string{"non-existent-user"},
expectedReasons: []string{"target user: non-existent-user not found"},
expectedReasons: []string{"user: non-existent-user not found"},
expectedNotDeleted: []string{},
},
{
@@ -867,7 +868,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs)
userInfos, err := am.BuildUserInfosForAccount(context.Background(), mockAccountID, mockUserID, maps.Values(account.Users))
assert.NoError(t, err)
err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs, userInfos)
if len(tc.expectedReasons) > 0 {
assert.Error(t, err)
var foundExpectedErrors int