mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-27 12:46:39 +00:00
Revert "Merge branch 'main' into feature/remote-debug"
This reverts commit6d6333058c, reversing changes made to446aded1f7.
This commit is contained in:
@@ -111,6 +111,3 @@ Generate gRpc code:
|
||||
#!/bin/bash
|
||||
protoc -I proto/ proto/management.proto --go_out=. --go-grpc_out=.
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -26,11 +26,7 @@ func (s *BaseServer) JobManager() *server.JobManager {
|
||||
|
||||
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
|
||||
return Create(s, func() integrated_validator.IntegratedValidator {
|
||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(
|
||||
context.Background(),
|
||||
s.PeersManager(),
|
||||
s.SettingsManager(),
|
||||
s.EventStore())
|
||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.EventStore())
|
||||
if err != nil {
|
||||
log.Errorf("failed to create integrated peer validator: %v", err)
|
||||
}
|
||||
|
||||
@@ -105,8 +105,6 @@ type DefaultAccountManager struct {
|
||||
accountUpdateLocks sync.Map
|
||||
updateAccountPeersBufferInterval atomic.Int64
|
||||
|
||||
loginFilter *loginFilter
|
||||
|
||||
disableDefaultPolicy bool
|
||||
}
|
||||
|
||||
@@ -216,7 +214,6 @@ func BuildManager(
|
||||
proxyController: proxyController,
|
||||
settingsManager: settingsManager,
|
||||
permissionsManager: permissionsManager,
|
||||
loginFilter: newLoginFilter(),
|
||||
disableDefaultPolicy: disableDefaultPolicy,
|
||||
}
|
||||
|
||||
@@ -303,6 +300,9 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
|
||||
// User that performs the update has to belong to the account.
|
||||
// Returns an updated Settings
|
||||
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
@@ -348,17 +348,13 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
}
|
||||
}
|
||||
|
||||
if err = transaction.SaveAccountSettings(ctx, accountID, newSettings); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers || groupsUpdated {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return transaction.SaveAccountSettings(ctx, accountID, newSettings)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -502,6 +498,8 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
|
||||
ctx := context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, hook.ExecutionContextKey, fmt.Sprintf("%s-PEER-EXPIRATION", hook.SystemSource))
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
expiredPeers, err := am.getExpiredPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -537,6 +535,9 @@ func (am *DefaultAccountManager) schedulePeerLoginExpiration(ctx context.Context
|
||||
// peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found
|
||||
func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
||||
return func() (time.Duration, bool) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
inactivePeers, err := am.getInactivePeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID)
|
||||
@@ -677,6 +678,8 @@ func (am *DefaultAccountManager) isCacheCold(ctx context.Context, store cacheSto
|
||||
|
||||
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
|
||||
func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1045,6 +1048,9 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
|
||||
return nil
|
||||
}
|
||||
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlockAccount()
|
||||
|
||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||
@@ -1137,20 +1143,12 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||
defer unlockAccount()
|
||||
|
||||
newUser := types.NewRegularUser(userAuth.UserId)
|
||||
newUser.AccountID = domainAccountID
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, domainAccountID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if settings != nil && settings.Extra != nil && settings.Extra.UserApprovalRequired {
|
||||
newUser.Blocked = true
|
||||
newUser.PendingApproval = true
|
||||
}
|
||||
|
||||
err = am.Store.SaveUser(ctx, newUser)
|
||||
err := am.Store.SaveUser(ctx, newUser)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1160,11 +1158,7 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
|
||||
return "", err
|
||||
}
|
||||
|
||||
if newUser.PendingApproval {
|
||||
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, map[string]any{"pending_approval": true})
|
||||
} else {
|
||||
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
|
||||
}
|
||||
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
|
||||
|
||||
return domainAccountID, nil
|
||||
}
|
||||
@@ -1363,6 +1357,13 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
return nil
|
||||
}
|
||||
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAuth.AccountId)
|
||||
defer func() {
|
||||
if unlockAccount != nil {
|
||||
unlockAccount()
|
||||
}
|
||||
}()
|
||||
|
||||
var addNewGroups []string
|
||||
var removeOldGroups []string
|
||||
var hasChanges bool
|
||||
@@ -1425,6 +1426,8 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
return fmt.Errorf("error incrementing network serial: %w", err)
|
||||
}
|
||||
}
|
||||
unlockAccount()
|
||||
unlockAccount = nil
|
||||
|
||||
return nil
|
||||
})
|
||||
@@ -1633,16 +1636,17 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U
|
||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool {
|
||||
return am.loginFilter.allowLogin(wgPubKey, metahash)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Debugf("SyncAndMarkPeer: took %v", time.Since(start))
|
||||
}()
|
||||
|
||||
accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer accountUnlock()
|
||||
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||
defer peerUnlock()
|
||||
|
||||
peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
|
||||
@@ -1653,18 +1657,22 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
metahash := metaHash(meta, realIP.String())
|
||||
am.loginFilter.addLogin(peerPubKey, metahash)
|
||||
|
||||
return peer, netMap, postureChecks, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
|
||||
accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer accountUnlock()
|
||||
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||
defer peerUnlock()
|
||||
|
||||
err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
|
||||
@@ -1673,6 +1681,12 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||
defer unlockPeer()
|
||||
|
||||
_, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
|
||||
if err != nil {
|
||||
return mapError(ctx, err)
|
||||
@@ -1717,9 +1731,7 @@ func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, account
|
||||
log.WithContext(ctx).Errorf("failed to get invalidated peer %s for account %s: %v", peerID, accountID, err)
|
||||
continue
|
||||
}
|
||||
if peer.UserID != "" {
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
if len(peers) > 0 {
|
||||
err := am.expireAndUpdatePeers(ctx, accountID, peers)
|
||||
@@ -1815,9 +1827,6 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
|
||||
RoutingPeerDNSResolutionEnabled: true,
|
||||
Extra: &types.ExtraSettings{
|
||||
UserApprovalRequired: true,
|
||||
},
|
||||
},
|
||||
Onboarding: types.AccountOnboarding{
|
||||
OnboardingFlowPending: true,
|
||||
@@ -1924,9 +1933,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
|
||||
RoutingPeerDNSResolutionEnabled: true,
|
||||
Extra: &types.ExtraSettings{
|
||||
UserApprovalRequired: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2112,6 +2118,9 @@ func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, pee
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validate user permissions: %w", err)
|
||||
|
||||
@@ -32,8 +32,6 @@ type Manager interface {
|
||||
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
|
||||
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
|
||||
RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error
|
||||
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
|
||||
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error)
|
||||
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error)
|
||||
@@ -79,7 +77,7 @@ type Manager interface {
|
||||
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
|
||||
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool, skipAutoApply bool) (*route.Route, error)
|
||||
CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||
SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error
|
||||
DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
||||
ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error)
|
||||
@@ -128,5 +126,4 @@ type Manager interface {
|
||||
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
||||
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
||||
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
||||
AllowSync(string, uint64) bool
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/prometheus/client_golang/prometheus/push"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -26,7 +25,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/cache"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -3048,14 +3046,19 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
||||
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||
b.ReportMetric(msPerOp, "ms/op")
|
||||
|
||||
minExpected := bc.minMsPerOpLocal
|
||||
maxExpected := bc.maxMsPerOpLocal
|
||||
if os.Getenv("CI") == "true" {
|
||||
minExpected = bc.minMsPerOpCICD
|
||||
maxExpected = bc.maxMsPerOpCICD
|
||||
testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "sync", "syncAndMark")
|
||||
}
|
||||
|
||||
if msPerOp > maxExpected {
|
||||
b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||
if msPerOp < minExpected {
|
||||
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||
}
|
||||
|
||||
if msPerOp > (maxExpected * 1.1) {
|
||||
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -3118,14 +3121,19 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
|
||||
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||
b.ReportMetric(msPerOp, "ms/op")
|
||||
|
||||
minExpected := bc.minMsPerOpLocal
|
||||
maxExpected := bc.maxMsPerOpLocal
|
||||
if os.Getenv("CI") == "true" {
|
||||
minExpected = bc.minMsPerOpCICD
|
||||
maxExpected = bc.maxMsPerOpCICD
|
||||
testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "existingPeer")
|
||||
}
|
||||
|
||||
if msPerOp > maxExpected {
|
||||
b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||
if msPerOp < minExpected {
|
||||
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||
}
|
||||
|
||||
if msPerOp > (maxExpected * 1.1) {
|
||||
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -3188,44 +3196,24 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||
b.ReportMetric(msPerOp, "ms/op")
|
||||
|
||||
minExpected := bc.minMsPerOpLocal
|
||||
maxExpected := bc.maxMsPerOpLocal
|
||||
if os.Getenv("CI") == "true" {
|
||||
minExpected = bc.minMsPerOpCICD
|
||||
maxExpected = bc.maxMsPerOpCICD
|
||||
testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer")
|
||||
}
|
||||
|
||||
if msPerOp > maxExpected {
|
||||
b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||
if msPerOp < minExpected {
|
||||
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||
}
|
||||
|
||||
if msPerOp > (maxExpected * 1.1) {
|
||||
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
exitCode := m.Run()
|
||||
|
||||
if exitCode == 0 && os.Getenv("CI") == "true" {
|
||||
runID := os.Getenv("GITHUB_RUN_ID")
|
||||
storeEngine := os.Getenv("NETBIRD_STORE_ENGINE")
|
||||
err := push.New("http://localhost:9091", "account_manager_benchmark").
|
||||
Collector(testing_tools.BenchmarkDuration).
|
||||
Grouping("ci_run", runID).
|
||||
Grouping("store_engine", storeEngine).
|
||||
Push()
|
||||
if err != nil {
|
||||
log.Printf("Failed to push metrics: %v", err)
|
||||
} else {
|
||||
time.Sleep(1 * time.Minute)
|
||||
_ = push.New("http://localhost:9091", "account_manager_benchmark").
|
||||
Grouping("ci_run", runID).
|
||||
Grouping("store_engine", storeEngine).
|
||||
Delete()
|
||||
}
|
||||
}
|
||||
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
@@ -3606,93 +3594,3 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
|
||||
require.Error(t, err, "should fail with invalid peer ID")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a domain-based account with user approval enabled
|
||||
existingAccountID := "existing-account"
|
||||
account := newAccountWithId(context.Background(), existingAccountID, "owner-user", "example.com", false)
|
||||
account.Settings.Extra = &types.ExtraSettings{
|
||||
UserApprovalRequired: true,
|
||||
}
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set the account as domain primary account
|
||||
account.IsDomainPrimaryAccount = true
|
||||
account.DomainCategory = types.PrivateCategory
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test adding new user to existing account with approval required
|
||||
newUserID := "new-user-id"
|
||||
userAuth := nbcontext.UserAuth{
|
||||
UserId: newUserID,
|
||||
Domain: "example.com",
|
||||
DomainCategory: types.PrivateCategory,
|
||||
}
|
||||
|
||||
acc, err := manager.Store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, acc.IsDomainPrimaryAccount, "Account should be primary for the domain")
|
||||
require.Equal(t, "example.com", acc.Domain, "Account domain should match")
|
||||
|
||||
returnedAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), userAuth)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, existingAccountID, returnedAccountID)
|
||||
|
||||
// Verify user was created with pending approval
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, newUserID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, user.Blocked, "User should be blocked when approval is required")
|
||||
assert.True(t, user.PendingApproval, "User should be pending approval")
|
||||
assert.Equal(t, existingAccountID, user.AccountID)
|
||||
}
|
||||
|
||||
func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a domain-based account without user approval
|
||||
ownerUserAuth := nbcontext.UserAuth{
|
||||
UserId: "owner-user",
|
||||
Domain: "example.com",
|
||||
DomainCategory: types.PrivateCategory,
|
||||
}
|
||||
existingAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), ownerUserAuth)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Modify the account to disable user approval
|
||||
account, err := manager.Store.GetAccount(context.Background(), existingAccountID)
|
||||
require.NoError(t, err)
|
||||
account.Settings.Extra = &types.ExtraSettings{
|
||||
UserApprovalRequired: false,
|
||||
}
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test adding new user to existing account without approval required
|
||||
newUserID := "new-user-id"
|
||||
userAuth := nbcontext.UserAuth{
|
||||
UserId: newUserID,
|
||||
Domain: "example.com",
|
||||
DomainCategory: types.PrivateCategory,
|
||||
}
|
||||
|
||||
returnedAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), userAuth)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, existingAccountID, returnedAccountID)
|
||||
|
||||
// Verify user was created without pending approval
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, newUserID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, user.Blocked, "User should not be blocked when approval is not required")
|
||||
assert.False(t, user.PendingApproval, "User should not be pending approval")
|
||||
assert.Equal(t, existingAccountID, user.AccountID)
|
||||
}
|
||||
|
||||
@@ -177,8 +177,6 @@ const (
|
||||
|
||||
AccountNetworkRangeUpdated Activity = 87
|
||||
PeerIPUpdated Activity = 88
|
||||
UserApproved Activity = 89
|
||||
UserRejected Activity = 90
|
||||
|
||||
JobCreatedByUser Activity = 89
|
||||
|
||||
@@ -290,9 +288,6 @@ var activityMap = map[Activity]Code{
|
||||
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
|
||||
|
||||
JobCreatedByUser: {"Create Job for peer", "peer.job.create"},
|
||||
|
||||
UserApproved: {"User approved", "user.approve"},
|
||||
UserRejected: {"User rejected", "user.reject"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -63,10 +63,12 @@ type Validator struct {
|
||||
}
|
||||
|
||||
var (
|
||||
errKeyNotFound = errors.New("unable to find appropriate key")
|
||||
errTokenEmpty = errors.New("required authorization token not found")
|
||||
errTokenInvalid = errors.New("token is invalid")
|
||||
errTokenParsing = errors.New("token could not be parsed")
|
||||
errKeyNotFound = errors.New("unable to find appropriate key")
|
||||
errInvalidAudience = errors.New("invalid audience")
|
||||
errInvalidIssuer = errors.New("invalid issuer")
|
||||
errTokenEmpty = errors.New("required authorization token not found")
|
||||
errTokenInvalid = errors.New("token is invalid")
|
||||
errTokenParsing = errors.New("token could not be parsed")
|
||||
)
|
||||
|
||||
func NewValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) *Validator {
|
||||
@@ -86,6 +88,24 @@ func NewValidator(issuer string, audienceList []string, keysLocation string, idp
|
||||
|
||||
func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc {
|
||||
return func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify 'aud' claim
|
||||
var checkAud bool
|
||||
for _, audience := range v.audienceList {
|
||||
checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false)
|
||||
if checkAud {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !checkAud {
|
||||
return token, errInvalidAudience
|
||||
}
|
||||
|
||||
// Verify 'issuer' claim
|
||||
checkIss := token.Claims.(jwt.MapClaims).VerifyIssuer(v.issuer, false)
|
||||
if !checkIss {
|
||||
return token, errInvalidIssuer
|
||||
}
|
||||
|
||||
// If keys are rotated, verify the keys prior to token validation
|
||||
if v.idpSignkeyRefreshEnabled {
|
||||
// If the keys are invalid, retrieve new ones
|
||||
@@ -124,7 +144,7 @@ func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc {
|
||||
}
|
||||
|
||||
// ValidateAndParse validates the token and returns the parsed token
|
||||
func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
|
||||
func (m *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
|
||||
// If the token is empty...
|
||||
if token == "" {
|
||||
// If we get here, the required token is missing
|
||||
@@ -133,13 +153,7 @@ func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.To
|
||||
}
|
||||
|
||||
// Now parse the token
|
||||
parsedToken, err := jwt.Parse(
|
||||
token,
|
||||
v.getKeyFunc(ctx),
|
||||
jwt.WithAudience(v.audienceList...),
|
||||
jwt.WithIssuer(v.issuer),
|
||||
jwt.WithIssuedAt(),
|
||||
)
|
||||
parsedToken, err := jwt.Parse(token, m.getKeyFunc(ctx))
|
||||
|
||||
// Check if there was an error in parsing...
|
||||
if err != nil {
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
|
||||
@@ -3,7 +3,7 @@ package auth
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
|
||||
@@ -20,9 +20,29 @@ import (
|
||||
|
||||
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
||||
type DNSConfigCache struct {
|
||||
CustomZones sync.Map
|
||||
NameServerGroups sync.Map
|
||||
}
|
||||
|
||||
// GetCustomZone retrieves a cached custom zone
|
||||
func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) {
|
||||
if c == nil {
|
||||
return nil, false
|
||||
}
|
||||
if value, ok := c.CustomZones.Load(key); ok {
|
||||
return value.(*proto.CustomZone), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// SetCustomZone stores a custom zone in the cache
|
||||
func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.CustomZones.Store(key, value)
|
||||
}
|
||||
|
||||
// GetNameServerGroup retrieves a cached name server group
|
||||
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
|
||||
if c == nil {
|
||||
@@ -93,11 +113,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
if err = transaction.SaveDNSSettings(ctx, accountID, dnsSettingsToSave); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return transaction.SaveDNSSettings(ctx, accountID, dnsSettingsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -192,8 +212,14 @@ func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSC
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
cacheKey := zone.Domain
|
||||
if cachedZone, exists := cache.GetCustomZone(cacheKey); exists {
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone)
|
||||
} else {
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
cache.SetCustomZone(cacheKey, protoZone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
|
||||
@@ -474,6 +474,15 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
t.Errorf("Results should be different for different inputs")
|
||||
}
|
||||
|
||||
// Verify that the cache contains elements from both configs
|
||||
if _, exists := cache.GetCustomZone("example.com"); !exists {
|
||||
t.Errorf("Cache should contain custom zone for example.com")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetCustomZone("example.org"); !exists {
|
||||
t.Errorf("Cache should contain custom zone for example.org")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group1"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group1'")
|
||||
}
|
||||
|
||||
@@ -67,6 +67,9 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
|
||||
|
||||
// CreateGroup object of the peers
|
||||
func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
@@ -93,6 +96,10 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
||||
}
|
||||
@@ -102,8 +109,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -122,6 +128,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
|
||||
// UpdateGroup object of the peers
|
||||
func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
@@ -167,11 +176,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return transaction.UpdateGroup(ctx, newGroup)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -202,45 +211,35 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var groupsToSave []*types.Group
|
||||
var updateAccountPeers bool
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
if err = transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
groupsToSave = append(groupsToSave, newGroup)
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
|
||||
if len(groupIDs) == 1 {
|
||||
return err
|
||||
}
|
||||
globalErr = errors.Join(globalErr, err)
|
||||
// continue updating other groups
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.CreateGroups(ctx, accountID, groupsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -253,7 +252,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return globalErr
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateGroups updates groups in the account.
|
||||
@@ -270,45 +269,35 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var groupsToSave []*types.Group
|
||||
var updateAccountPeers bool
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
groupsToSave = append(groupsToSave, newGroup)
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
|
||||
if len(groups) == 1 {
|
||||
return err
|
||||
}
|
||||
globalErr = errors.Join(globalErr, err)
|
||||
// continue updating other groups
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.UpdateGroups(ctx, accountID, groupsToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -321,7 +310,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return globalErr
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareGroupEvents prepares a list of event functions to be stored.
|
||||
@@ -393,6 +382,8 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
|
||||
|
||||
// DeleteGroup object of the peers.
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
|
||||
}
|
||||
|
||||
@@ -432,11 +423,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteGroups(ctx, accountID, groupIDsToDelete); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return transaction.DeleteGroups(ctx, accountID, groupIDsToDelete)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -451,6 +442,9 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
@@ -460,11 +454,11 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return transaction.AddPeerToGroup(ctx, accountID, peerID, groupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -479,6 +473,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
|
||||
// GroupAddResource appends resource to the group
|
||||
func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
@@ -498,11 +495,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, group); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return transaction.UpdateGroup(ctx, group)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -517,6 +514,9 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
@@ -526,11 +526,11 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return transaction.RemovePeerFromGroup(ctx, peerID, groupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -545,6 +545,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
|
||||
// GroupDeleteResource removes resource from the group
|
||||
func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
@@ -564,11 +567,11 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, group); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return transaction.UpdateGroup(ctx, group)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -604,6 +607,13 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
|
||||
newGroup.ID = xid.New().String()
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -648,7 +648,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
|
||||
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, newRoute.SkipAutoApply,
|
||||
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -41,30 +40,21 @@ import (
|
||||
internalStatus "github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS"
|
||||
envBlockPeers = "NB_BLOCK_SAME_PEERS"
|
||||
)
|
||||
|
||||
// GRPCServer an instance of a Management gRPC API server
|
||||
type GRPCServer struct {
|
||||
accountManager account.Manager
|
||||
settingsManager settings.Manager
|
||||
wgKey wgtypes.Key
|
||||
proto.UnimplementedManagementServiceServer
|
||||
|
||||
peersUpdateManager *PeersUpdateManager
|
||||
jobManager *JobManager
|
||||
config *nbconfig.Config
|
||||
secretsManager SecretsManager
|
||||
appMetrics telemetry.AppMetrics
|
||||
ephemeralManager *EphemeralManager
|
||||
peerLocks sync.Map
|
||||
authManager auth.Manager
|
||||
|
||||
logBlockedPeers bool
|
||||
blockPeersWithSameConfig bool
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
peersUpdateManager *PeersUpdateManager
|
||||
jobManager *JobManager
|
||||
config *nbconfig.Config
|
||||
secretsManager SecretsManager
|
||||
appMetrics telemetry.AppMetrics
|
||||
ephemeralManager *EphemeralManager
|
||||
peerLocks sync.Map
|
||||
authManager auth.Manager
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@@ -95,24 +85,19 @@ func NewServer(
|
||||
}
|
||||
}
|
||||
|
||||
logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true"
|
||||
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
|
||||
|
||||
return &GRPCServer{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
jobManager: jobManager,
|
||||
accountManager: accountManager,
|
||||
settingsManager: settingsManager,
|
||||
config: config,
|
||||
secretsManager: secretsManager,
|
||||
authManager: authManager,
|
||||
appMetrics: appMetrics,
|
||||
ephemeralManager: ephemeralManager,
|
||||
logBlockedPeers: logBlockedPeers,
|
||||
blockPeersWithSameConfig: blockPeersWithSameConfig,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
jobManager: jobManager,
|
||||
accountManager: accountManager,
|
||||
settingsManager: settingsManager,
|
||||
config: config,
|
||||
secretsManager: secretsManager,
|
||||
authManager: authManager,
|
||||
appMetrics: appMetrics,
|
||||
ephemeralManager: ephemeralManager,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -192,6 +177,9 @@ func (s *GRPCServer) Job(srv proto.ManagementService_JobServer) error {
|
||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
reqStart := time.Now()
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequest()
|
||||
}
|
||||
|
||||
ctx := srv.Context()
|
||||
|
||||
@@ -200,27 +188,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
realIP := getRealIP(ctx)
|
||||
sRealIP := realIP.String()
|
||||
peerMeta := extractPeerMeta(ctx, syncReq.GetMeta())
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
if !s.accountManager.AllowSync(peerKey.String(), metahashed) {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
|
||||
}
|
||||
if s.logBlockedPeers {
|
||||
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
|
||||
}
|
||||
if s.blockPeersWithSameConfig {
|
||||
return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn)
|
||||
}
|
||||
}
|
||||
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequest()
|
||||
}
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
|
||||
@@ -244,12 +211,14 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
||||
realIP := getRealIP(ctx)
|
||||
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
if syncReq.GetMeta() == nil {
|
||||
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
}
|
||||
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||
return mapError(ctx, err)
|
||||
@@ -267,7 +236,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
|
||||
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID)
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
||||
}
|
||||
|
||||
unlock()
|
||||
@@ -366,7 +335,6 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
|
||||
}
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -505,9 +473,6 @@ func mapError(ctx context.Context, err error) error {
|
||||
default:
|
||||
}
|
||||
}
|
||||
if errors.Is(err, internalStatus.ErrPeerAlreadyLoggedIn) {
|
||||
return status.Error(codes.PermissionDenied, internalStatus.ErrPeerAlreadyLoggedIn.Error())
|
||||
}
|
||||
log.WithContext(ctx).Errorf("got an unhandled error: %s", err)
|
||||
return status.Errorf(codes.Internal, "failed handling request")
|
||||
}
|
||||
@@ -599,9 +564,16 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa
|
||||
// In case of the successful registration login is also successful
|
||||
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
reqStart := time.Now()
|
||||
defer func() {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart))
|
||||
}
|
||||
}()
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequest()
|
||||
}
|
||||
realIP := getRealIP(ctx)
|
||||
sRealIP := realIP.String()
|
||||
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
||||
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
loginReq := &proto.LoginRequest{}
|
||||
peerKey, err := s.parseRequest(ctx, req, loginReq)
|
||||
@@ -609,24 +581,6 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
if !s.accountManager.AllowSync(peerKey.String(), metahashed) {
|
||||
if s.logBlockedPeers {
|
||||
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
|
||||
}
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequestBlocked()
|
||||
}
|
||||
if s.blockPeersWithSameConfig {
|
||||
return nil, internalStatus.ErrPeerAlreadyLoggedIn
|
||||
}
|
||||
}
|
||||
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequest()
|
||||
}
|
||||
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||
@@ -637,12 +591,6 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
|
||||
defer func() {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
|
||||
}
|
||||
}()
|
||||
|
||||
if loginReq.GetMeta() == nil {
|
||||
msg := status.Errorf(codes.FailedPrecondition,
|
||||
"peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
@@ -663,7 +611,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
|
||||
WireGuardPubKey: peerKey.String(),
|
||||
SSHKey: string(sshKey),
|
||||
Meta: peerMeta,
|
||||
Meta: extractPeerMeta(ctx, loginReq.GetMeta()),
|
||||
UserID: userID,
|
||||
SetupKey: loginReq.GetSetupKey(),
|
||||
ConnectionIP: realIP,
|
||||
@@ -1129,6 +1077,8 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
|
||||
|
||||
log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start))
|
||||
|
||||
return &proto.Empty{}, nil
|
||||
|
||||
@@ -11,11 +11,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -198,7 +198,6 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
if req.Settings.Extra != nil {
|
||||
settings.Extra = &types.ExtraSettings{
|
||||
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
|
||||
UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
|
||||
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
|
||||
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
|
||||
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
|
||||
@@ -328,7 +327,6 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
||||
if settings.Extra != nil {
|
||||
apiSettings.Extra = &api.AccountExtraSettings{
|
||||
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
|
||||
UserApprovalRequired: settings.Extra.UserApprovalRequired,
|
||||
NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled,
|
||||
NetworkTrafficLogsGroups: settings.Extra.FlowGroups,
|
||||
NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled,
|
||||
|
||||
@@ -15,11 +15,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func initAccountsTestData(t *testing.T, account *types.Account) *handler {
|
||||
|
||||
@@ -488,33 +488,33 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
|
||||
}
|
||||
|
||||
return &api.PeerBatch{
|
||||
CreatedAt: peer.CreatedAt,
|
||||
Id: peer.ID,
|
||||
Name: peer.Name,
|
||||
Ip: peer.IP.String(),
|
||||
ConnectionIp: peer.Location.ConnectionIP.String(),
|
||||
Connected: peer.Status.Connected,
|
||||
LastSeen: peer.Status.LastSeen,
|
||||
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
|
||||
KernelVersion: peer.Meta.KernelVersion,
|
||||
GeonameId: int(peer.Location.GeoNameID),
|
||||
Version: peer.Meta.WtVersion,
|
||||
Groups: groupsInfo,
|
||||
SshEnabled: peer.SSHEnabled,
|
||||
Hostname: peer.Meta.Hostname,
|
||||
UserId: peer.UserID,
|
||||
UiVersion: peer.Meta.UIVersion,
|
||||
DnsLabel: fqdn(peer, dnsDomain),
|
||||
ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain),
|
||||
LoginExpirationEnabled: peer.LoginExpirationEnabled,
|
||||
LastLogin: peer.GetLastLogin(),
|
||||
LoginExpired: peer.Status.LoginExpired,
|
||||
AccessiblePeersCount: accessiblePeersCount,
|
||||
CountryCode: peer.Location.CountryCode,
|
||||
CityName: peer.Location.CityName,
|
||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||
CreatedAt: peer.CreatedAt,
|
||||
Id: peer.ID,
|
||||
Name: peer.Name,
|
||||
Ip: peer.IP.String(),
|
||||
ConnectionIp: peer.Location.ConnectionIP.String(),
|
||||
Connected: peer.Status.Connected,
|
||||
LastSeen: peer.Status.LastSeen,
|
||||
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
|
||||
KernelVersion: peer.Meta.KernelVersion,
|
||||
GeonameId: int(peer.Location.GeoNameID),
|
||||
Version: peer.Meta.WtVersion,
|
||||
Groups: groupsInfo,
|
||||
SshEnabled: peer.SSHEnabled,
|
||||
Hostname: peer.Meta.Hostname,
|
||||
UserId: peer.UserID,
|
||||
UiVersion: peer.Meta.UIVersion,
|
||||
DnsLabel: fqdn(peer, dnsDomain),
|
||||
ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain),
|
||||
LoginExpirationEnabled: peer.LoginExpirationEnabled,
|
||||
LastLogin: peer.GetLastLogin(),
|
||||
LoginExpired: peer.Status.LoginExpired,
|
||||
AccessiblePeersCount: accessiblePeersCount,
|
||||
CountryCode: peer.Location.CountryCode,
|
||||
CityName: peer.Location.CityName,
|
||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||
|
||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||
Ephemeral: peer.Ephemeral,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,19 +8,17 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const failedToConvertRoute = "failed to convert route to response: %v"
|
||||
|
||||
const exitNodeCIDR = "0.0.0.0/0"
|
||||
|
||||
// handler is the routes handler of the account
|
||||
type handler struct {
|
||||
accountManager account.Manager
|
||||
@@ -126,16 +124,8 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
|
||||
accessControlGroupIds = *req.AccessControlGroups
|
||||
}
|
||||
|
||||
// Set default skipAutoApply value for exit nodes (0.0.0.0/0 routes)
|
||||
skipAutoApply := false
|
||||
if req.SkipAutoApply != nil {
|
||||
skipAutoApply = *req.SkipAutoApply
|
||||
} else if newPrefix.String() == exitNodeCIDR {
|
||||
skipAutoApply = false
|
||||
}
|
||||
|
||||
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
|
||||
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute, skipAutoApply)
|
||||
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute)
|
||||
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -152,31 +142,23 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *handler) validateRoute(req api.PostApiRoutesJSONRequestBody) error {
|
||||
return h.validateRouteCommon(req.Network, req.Domains, req.Peer, req.PeerGroups, req.NetworkId)
|
||||
}
|
||||
|
||||
func (h *handler) validateRouteUpdate(req api.PutApiRoutesRouteIdJSONRequestBody) error {
|
||||
return h.validateRouteCommon(req.Network, req.Domains, req.Peer, req.PeerGroups, req.NetworkId)
|
||||
}
|
||||
|
||||
func (h *handler) validateRouteCommon(network *string, domains *[]string, peer *string, peerGroups *[]string, networkId string) error {
|
||||
if network != nil && domains != nil {
|
||||
if req.Network != nil && req.Domains != nil {
|
||||
return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided")
|
||||
}
|
||||
|
||||
if network == nil && domains == nil {
|
||||
if req.Network == nil && req.Domains == nil {
|
||||
return status.Errorf(status.InvalidArgument, "either 'network' or 'domains' should be provided")
|
||||
}
|
||||
|
||||
if peer == nil && peerGroups == nil {
|
||||
if req.Peer == nil && req.PeerGroups == nil {
|
||||
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided")
|
||||
}
|
||||
|
||||
if peer != nil && peerGroups != nil {
|
||||
if req.Peer != nil && req.PeerGroups != nil {
|
||||
return status.Errorf(status.InvalidArgument, "only one of 'peer' or 'peer_groups' should be provided")
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(networkId) > route.MaxNetIDChar || networkId == "" {
|
||||
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
|
||||
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d characters",
|
||||
route.MaxNetIDChar)
|
||||
}
|
||||
@@ -213,7 +195,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.validateRouteUpdate(req); err != nil {
|
||||
if err := h.validateRoute(req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
@@ -223,24 +205,15 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
peerID = *req.Peer
|
||||
}
|
||||
|
||||
// Set default skipAutoApply value for exit nodes (0.0.0.0/0 routes)
|
||||
skipAutoApply := false
|
||||
if req.SkipAutoApply != nil {
|
||||
skipAutoApply = *req.SkipAutoApply
|
||||
} else if req.Network != nil && *req.Network == exitNodeCIDR {
|
||||
skipAutoApply = false
|
||||
}
|
||||
|
||||
newRoute := &route.Route{
|
||||
ID: route.ID(routeID),
|
||||
NetID: route.NetID(req.NetworkId),
|
||||
Masquerade: req.Masquerade,
|
||||
Metric: req.Metric,
|
||||
Description: req.Description,
|
||||
Enabled: req.Enabled,
|
||||
Groups: req.Groups,
|
||||
KeepRoute: req.KeepRoute,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
ID: route.ID(routeID),
|
||||
NetID: route.NetID(req.NetworkId),
|
||||
Masquerade: req.Masquerade,
|
||||
Metric: req.Metric,
|
||||
Description: req.Description,
|
||||
Enabled: req.Enabled,
|
||||
Groups: req.Groups,
|
||||
KeepRoute: req.KeepRoute,
|
||||
}
|
||||
|
||||
if req.Domains != nil {
|
||||
@@ -348,19 +321,18 @@ func toRouteResponse(serverRoute *route.Route) (*api.Route, error) {
|
||||
}
|
||||
network := serverRoute.Network.String()
|
||||
route := &api.Route{
|
||||
Id: string(serverRoute.ID),
|
||||
Description: serverRoute.Description,
|
||||
NetworkId: string(serverRoute.NetID),
|
||||
Enabled: serverRoute.Enabled,
|
||||
Peer: &serverRoute.Peer,
|
||||
Network: &network,
|
||||
Domains: &domains,
|
||||
NetworkType: serverRoute.NetworkType.String(),
|
||||
Masquerade: serverRoute.Masquerade,
|
||||
Metric: serverRoute.Metric,
|
||||
Groups: serverRoute.Groups,
|
||||
KeepRoute: serverRoute.KeepRoute,
|
||||
SkipAutoApply: &serverRoute.SkipAutoApply,
|
||||
Id: string(serverRoute.ID),
|
||||
Description: serverRoute.Description,
|
||||
NetworkId: string(serverRoute.NetID),
|
||||
Enabled: serverRoute.Enabled,
|
||||
Peer: &serverRoute.Peer,
|
||||
Network: &network,
|
||||
Domains: &domains,
|
||||
NetworkType: serverRoute.NetworkType.String(),
|
||||
Masquerade: serverRoute.Masquerade,
|
||||
Metric: serverRoute.Metric,
|
||||
Groups: serverRoute.Groups,
|
||||
KeepRoute: serverRoute.KeepRoute,
|
||||
}
|
||||
|
||||
if len(serverRoute.PeerGroups) > 0 {
|
||||
|
||||
@@ -15,13 +15,13 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -62,22 +62,21 @@ func initRoutesTestData() *handler {
|
||||
return &handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) {
|
||||
switch routeID {
|
||||
case existingRouteID:
|
||||
if routeID == existingRouteID {
|
||||
return baseExistingRoute, nil
|
||||
case existingRouteID2:
|
||||
}
|
||||
if routeID == existingRouteID2 {
|
||||
route := baseExistingRoute.Copy()
|
||||
route.PeerGroups = []string{existingGroupID}
|
||||
return route, nil
|
||||
case existingRouteID3:
|
||||
} else if routeID == existingRouteID3 {
|
||||
route := baseExistingRoute.Copy()
|
||||
route.Domains = domain.List{existingDomain}
|
||||
return route, nil
|
||||
default:
|
||||
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
||||
}
|
||||
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
||||
},
|
||||
CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool, skipAutoApply bool) (*route.Route, error) {
|
||||
CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
|
||||
if peerID == notFoundPeerID {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
}
|
||||
@@ -104,7 +103,6 @@ func initRoutesTestData() *handler {
|
||||
Groups: groups,
|
||||
KeepRoute: keepRoute,
|
||||
AccessControlGroups: accessControlGroups,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
}, nil
|
||||
},
|
||||
SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error {
|
||||
@@ -192,20 +190,19 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/routes",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","network_id":"awesomeNet","Peer":"%s","groups":["%s"],"skip_auto_apply":false}`, existingPeerID, existingGroupID))),
|
||||
[]byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID))),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: &api.Route{
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: util.ToPtr("192.168.0.0/16"),
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.IPv4NetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
SkipAutoApply: util.ToPtr(false),
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: util.ToPtr("192.168.0.0/16"),
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.IPv4NetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -213,22 +210,21 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/routes",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(fmt.Sprintf(`{"description":"Post","domains":["example.com"],"network_id":"domainNet","peer":"%s","groups":["%s"],"keep_route":true,"skip_auto_apply":false}`, existingPeerID, existingGroupID))),
|
||||
[]byte(fmt.Sprintf(`{"description":"Post","domains":["example.com"],"network_id":"domainNet","peer":"%s","groups":["%s"],"keep_route":true}`, existingPeerID, existingGroupID))),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: &api.Route{
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "domainNet",
|
||||
Network: util.ToPtr("invalid Prefix"),
|
||||
KeepRoute: true,
|
||||
Domains: &[]string{existingDomain},
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.DomainNetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
SkipAutoApply: util.ToPtr(false),
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "domainNet",
|
||||
Network: util.ToPtr("invalid Prefix"),
|
||||
KeepRoute: true,
|
||||
Domains: &[]string{existingDomain},
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.DomainNetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -236,7 +232,7 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/routes",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"],\"skip_auto_apply\":false}", existingPeerID, existingGroupID, existingGroupID))),
|
||||
[]byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"]}", existingPeerID, existingGroupID, existingGroupID))),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: &api.Route{
|
||||
@@ -250,7 +246,6 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
AccessControlGroups: &[]string{existingGroupID},
|
||||
SkipAutoApply: util.ToPtr(false),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -341,63 +336,60 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
name: "Network PUT OK",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/routes/" + existingRouteID,
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"is_selected\":true}", existingPeerID, existingGroupID)),
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"]}", existingPeerID, existingGroupID)),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: &api.Route{
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: util.ToPtr("192.168.0.0/16"),
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.IPv4NetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
SkipAutoApply: util.ToPtr(false),
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: util.ToPtr("192.168.0.0/16"),
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.IPv4NetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Domains PUT OK",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/routes/" + existingRouteID,
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"],"keep_route":true,"skip_auto_apply":false}`, existingPeerID, existingGroupID)),
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"],"keep_route":true}`, existingPeerID, existingGroupID)),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: &api.Route{
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: util.ToPtr("invalid Prefix"),
|
||||
Domains: &[]string{existingDomain},
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.DomainNetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
KeepRoute: true,
|
||||
SkipAutoApply: util.ToPtr(false),
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: util.ToPtr("invalid Prefix"),
|
||||
Domains: &[]string{existingDomain},
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.DomainNetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
KeepRoute: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "PUT OK when peer_groups provided",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/routes/" + existingRouteID,
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"peer_groups\":[\"%s\"],\"groups\":[\"%s\"],\"skip_auto_apply\":false}", existingGroupID, existingGroupID)),
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"peer_groups\":[\"%s\"],\"groups\":[\"%s\"]}", existingGroupID, existingGroupID)),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: &api.Route{
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: util.ToPtr("192.168.0.0/16"),
|
||||
Peer: &emptyString,
|
||||
PeerGroups: &[]string{existingGroupID},
|
||||
NetworkType: route.IPv4NetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
SkipAutoApply: util.ToPtr(false),
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: util.ToPtr("192.168.0.0/16"),
|
||||
Peer: &emptyString,
|
||||
PeerGroups: &[]string{existingGroupID},
|
||||
NetworkType: route.IPv4NetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
)
|
||||
@@ -31,8 +31,6 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS")
|
||||
addUsersTokensEndpoint(accountManager, router)
|
||||
}
|
||||
|
||||
@@ -325,76 +323,17 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
|
||||
}
|
||||
|
||||
isCurrent := user.ID == currenUserID
|
||||
|
||||
return &api.User{
|
||||
Id: user.ID,
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
AutoGroups: autoGroups,
|
||||
Status: userStatus,
|
||||
IsCurrent: &isCurrent,
|
||||
IsServiceUser: &user.IsServiceUser,
|
||||
IsBlocked: user.IsBlocked,
|
||||
LastLogin: &user.LastLogin,
|
||||
Issued: &user.Issued,
|
||||
PendingApproval: user.PendingApproval,
|
||||
Id: user.ID,
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
AutoGroups: autoGroups,
|
||||
Status: userStatus,
|
||||
IsCurrent: &isCurrent,
|
||||
IsServiceUser: &user.IsServiceUser,
|
||||
IsBlocked: user.IsBlocked,
|
||||
LastLogin: &user.LastLogin,
|
||||
Issued: &user.Issued,
|
||||
}
|
||||
}
|
||||
|
||||
// approveUser is a POST request to approve a user that is pending approval
|
||||
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
userResponse := toUserResponse(user, userAuth.UserId)
|
||||
util.WriteJSONObject(r.Context(), w, userResponse)
|
||||
}
|
||||
|
||||
// rejectUser is a DELETE request to reject a user that is pending approval
|
||||
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodDelete {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
@@ -16,13 +16,13 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -725,133 +725,3 @@ func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[stri
|
||||
}
|
||||
return modules
|
||||
}
|
||||
|
||||
func TestApproveUserEndpoint(t *testing.T) {
|
||||
adminUser := &types.User{
|
||||
Id: "admin-user",
|
||||
Role: types.UserRoleAdmin,
|
||||
AccountID: existingAccountID,
|
||||
AutoGroups: []string{},
|
||||
}
|
||||
|
||||
pendingUser := &types.User{
|
||||
Id: "pending-user",
|
||||
Role: types.UserRoleUser,
|
||||
AccountID: existingAccountID,
|
||||
Blocked: true,
|
||||
PendingApproval: true,
|
||||
AutoGroups: []string{},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
expectedBody bool
|
||||
requestingUser *types.User
|
||||
}{
|
||||
{
|
||||
name: "approve user as admin should return 200",
|
||||
expectedStatus: 200,
|
||||
expectedBody: true,
|
||||
requestingUser: adminUser,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{}
|
||||
am.ApproveUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) {
|
||||
approvedUserInfo := &types.UserInfo{
|
||||
ID: pendingUser.Id,
|
||||
Email: "pending@example.com",
|
||||
Name: "Pending User",
|
||||
Role: string(pendingUser.Role),
|
||||
AutoGroups: []string{},
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
PendingApproval: false,
|
||||
LastLogin: time.Now(),
|
||||
Issued: types.UserIssuedAPI,
|
||||
}
|
||||
return approvedUserInfo, nil
|
||||
}
|
||||
|
||||
handler := newHandler(am)
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST")
|
||||
|
||||
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
userAuth := nbcontext.UserAuth{
|
||||
AccountId: existingAccountID,
|
||||
UserId: tc.requestingUser.Id,
|
||||
}
|
||||
ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedBody {
|
||||
var response api.User
|
||||
err = json.Unmarshal(rr.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pending-user", response.Id)
|
||||
assert.False(t, response.IsBlocked)
|
||||
assert.False(t, response.PendingApproval)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRejectUserEndpoint(t *testing.T) {
|
||||
adminUser := &types.User{
|
||||
Id: "admin-user",
|
||||
Role: types.UserRoleAdmin,
|
||||
AccountID: existingAccountID,
|
||||
AutoGroups: []string{},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
requestingUser *types.User
|
||||
}{
|
||||
{
|
||||
name: "reject user as admin should return 200",
|
||||
expectedStatus: 200,
|
||||
requestingUser: adminUser,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{}
|
||||
am.RejectUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
handler := newHandler(am)
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/users/{userId}/reject", handler.rejectUser).Methods("DELETE")
|
||||
|
||||
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
userAuth := nbcontext.UserAuth{
|
||||
AccountId: existingAccountID,
|
||||
UserId: tc.requestingUser.Id,
|
||||
}
|
||||
ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,9 +13,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
|
||||
@@ -8,15 +8,16 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -17,9 +17,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
)
|
||||
|
||||
const modulePeers = "peers"
|
||||
@@ -48,7 +47,7 @@ func BenchmarkUpdatePeer(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesPeers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -66,7 +65,7 @@ func BenchmarkUpdatePeer(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -83,7 +82,7 @@ func BenchmarkGetOnePeer(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesPeers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -93,7 +92,7 @@ func BenchmarkGetOnePeer(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -110,7 +109,7 @@ func BenchmarkGetAllPeers(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesPeers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -120,7 +119,7 @@ func BenchmarkGetAllPeers(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -137,7 +136,7 @@ func BenchmarkDeletePeer(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesPeers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), 1000, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -147,7 +146,7 @@ func BenchmarkDeletePeer(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,9 +17,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
)
|
||||
|
||||
// Map to store peers, groups, users, and setupKeys by name
|
||||
@@ -48,7 +47,7 @@ func BenchmarkCreateSetupKey(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesSetupKeys {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -70,7 +69,7 @@ func BenchmarkCreateSetupKey(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -87,7 +86,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesSetupKeys {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -110,7 +109,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -127,7 +126,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesSetupKeys {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -137,7 +136,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -154,7 +153,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesSetupKeys {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -164,7 +163,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -181,7 +180,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesSetupKeys {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -191,7 +190,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,9 +18,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
)
|
||||
|
||||
const moduleUsers = "users"
|
||||
@@ -47,7 +46,7 @@ func BenchmarkUpdateUser(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesUsers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -72,7 +71,7 @@ func BenchmarkUpdateUser(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -85,18 +84,18 @@ func BenchmarkGetOneUser(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesUsers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId)
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId)
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -111,18 +110,18 @@ func BenchmarkGetAllUsers(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesUsers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId)
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId)
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -137,7 +136,7 @@ func BenchmarkDeleteUsers(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesUsers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -148,7 +147,7 @@ func BenchmarkDeleteUsers(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,10 +15,9 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func Test_SetupKeys_Create(t *testing.T) {
|
||||
@@ -288,7 +287,7 @@ func Test_SetupKeys_Create(t *testing.T) {
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
@@ -573,7 +572,7 @@ func Test_SetupKeys_Update(t *testing.T) {
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
@@ -752,7 +751,7 @@ func Test_SetupKeys_Get(t *testing.T) {
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId)
|
||||
|
||||
@@ -904,7 +903,7 @@ func Test_SetupKeys_GetAll(t *testing.T) {
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, tc.requestPath, user.userId)
|
||||
|
||||
@@ -1088,7 +1087,7 @@ func Test_SetupKeys_Delete(t *testing.T) {
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId)
|
||||
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
http2 "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
)
|
||||
|
||||
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test store: %v", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create metrics: %v", err)
|
||||
}
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId)
|
||||
done := make(chan struct{})
|
||||
if validateUpdate {
|
||||
go func() {
|
||||
if expectedPeerUpdate != nil {
|
||||
peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate)
|
||||
} else {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
}
|
||||
|
||||
geoMock := &geolocation.Mock{}
|
||||
validatorMock := server.MockIntegratedValidator{}
|
||||
proxyController := integrations.NewController(store)
|
||||
userManager := users.NewManager(store)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
|
||||
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
||||
authManager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
authManagerMock := &auth.MockManager{
|
||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
|
||||
MarkPATUsedFunc: authManager.MarkPATUsed,
|
||||
GetPATInfoFunc: authManager.GetPATInfo,
|
||||
}
|
||||
|
||||
networksManagerMock := networks.NewManagerMock()
|
||||
resourcesManagerMock := resources.NewManagerMock()
|
||||
routersManagerMock := routers.NewManagerMock()
|
||||
groupsManagerMock := groups.NewManagerMock()
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
return apiHandler, am, done
|
||||
}
|
||||
|
||||
func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) {
|
||||
t.Helper()
|
||||
select {
|
||||
case msg := <-updateMessage:
|
||||
t.Errorf("Unexpected message received: %+v", msg)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case msg := <-updateMessage:
|
||||
if msg == nil {
|
||||
t.Errorf("Received nil update message, expected valid message")
|
||||
}
|
||||
assert.Equal(t, expected, msg)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Errorf("Timed out waiting for update message")
|
||||
}
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
|
||||
userAuth := nbcontext.UserAuth{}
|
||||
|
||||
switch token {
|
||||
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
|
||||
userAuth.UserId = token
|
||||
userAuth.AccountId = "testAccountId"
|
||||
userAuth.Domain = "test.com"
|
||||
userAuth.DomainCategory = "private"
|
||||
case "otherUserId":
|
||||
userAuth.UserId = "otherUserId"
|
||||
userAuth.AccountId = "otherAccountId"
|
||||
userAuth.Domain = "other.com"
|
||||
userAuth.DomainCategory = "private"
|
||||
case "invalidToken":
|
||||
return userAuth, nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
jwtToken := jwt.New(jwt.SigningMethodHS256)
|
||||
return userAuth, jwtToken, nil
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package testing_tools
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -13,12 +14,32 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
)
|
||||
@@ -202,11 +223,11 @@ func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedSta
|
||||
return content, expectedStatus == http.StatusOK
|
||||
}
|
||||
|
||||
func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, setupKeys int) {
|
||||
func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, groups, users, setupKeys int) {
|
||||
b.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
acc, err := am.GetAccount(ctx, TestAccountId)
|
||||
account, err := am.GetAccount(ctx, TestAccountId)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
@@ -222,23 +243,23 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se
|
||||
Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true},
|
||||
UserID: TestUserId,
|
||||
}
|
||||
acc.Peers[peer.ID] = peer
|
||||
account.Peers[peer.ID] = peer
|
||||
}
|
||||
|
||||
// Create users
|
||||
for i := 0; i < users; i++ {
|
||||
user := &types.User{
|
||||
Id: fmt.Sprintf("olduser-%d", i),
|
||||
AccountID: acc.Id,
|
||||
AccountID: account.Id,
|
||||
Role: types.UserRoleUser,
|
||||
}
|
||||
acc.Users[user.Id] = user
|
||||
account.Users[user.Id] = user
|
||||
}
|
||||
|
||||
for i := 0; i < setupKeys; i++ {
|
||||
key := &types.SetupKey{
|
||||
Id: fmt.Sprintf("oldkey-%d", i),
|
||||
AccountID: acc.Id,
|
||||
AccountID: account.Id,
|
||||
AutoGroups: []string{"someGroupID"},
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
ExpiresAt: util.ToPtr(time.Now().Add(ExpiresIn * time.Second)),
|
||||
@@ -246,11 +267,11 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se
|
||||
Type: "reusable",
|
||||
UsageLimit: 0,
|
||||
}
|
||||
acc.SetupKeys[key.Id] = key
|
||||
account.SetupKeys[key.Id] = key
|
||||
}
|
||||
|
||||
// Create groups and policies
|
||||
acc.Policies = make([]*types.Policy, 0, groups)
|
||||
account.Policies = make([]*types.Policy, 0, groups)
|
||||
for i := 0; i < groups; i++ {
|
||||
groupID := fmt.Sprintf("group-%d", i)
|
||||
group := &types.Group{
|
||||
@@ -261,7 +282,7 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se
|
||||
peerIndex := i*(peers/groups) + j
|
||||
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
|
||||
}
|
||||
acc.Groups[groupID] = group
|
||||
account.Groups[groupID] = group
|
||||
|
||||
// Create a policy for this group
|
||||
policy := &types.Policy{
|
||||
@@ -281,10 +302,10 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se
|
||||
},
|
||||
},
|
||||
}
|
||||
acc.Policies = append(acc.Policies, policy)
|
||||
account.Policies = append(account.Policies, policy)
|
||||
}
|
||||
|
||||
acc.PostureChecks = []*posture.Checks{
|
||||
account.PostureChecks = []*posture.Checks{
|
||||
{
|
||||
ID: "PostureChecksAll",
|
||||
Name: "All",
|
||||
@@ -296,38 +317,52 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se
|
||||
},
|
||||
}
|
||||
|
||||
store := am.GetStore()
|
||||
|
||||
err = store.SaveAccount(context.Background(), acc)
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to save account: %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func EvaluateAPIBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) {
|
||||
b.Helper()
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
b.Fatalf("Benchmark %s failed: unexpected status code %d", testCase, recorder.Code)
|
||||
}
|
||||
|
||||
EvaluateBenchmarkResults(b, testCase, duration, module, operation)
|
||||
|
||||
}
|
||||
|
||||
func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, module string, operation string) {
|
||||
func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) {
|
||||
b.Helper()
|
||||
|
||||
branch := os.Getenv("GIT_BRANCH")
|
||||
if branch == "" && os.Getenv("CI") == "true" {
|
||||
if branch == "" {
|
||||
b.Fatalf("environment variable GIT_BRANCH is not set")
|
||||
}
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
b.Fatalf("Benchmark %s failed: unexpected status code %d", testCase, recorder.Code)
|
||||
}
|
||||
|
||||
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||
|
||||
gauge := BenchmarkDuration.WithLabelValues(module, operation, testCase, branch)
|
||||
gauge.Set(msPerOp)
|
||||
|
||||
b.ReportMetric(msPerOp, "ms/op")
|
||||
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
|
||||
userAuth := nbcontext.UserAuth{}
|
||||
|
||||
switch token {
|
||||
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
|
||||
userAuth.UserId = token
|
||||
userAuth.AccountId = "testAccountId"
|
||||
userAuth.Domain = "test.com"
|
||||
userAuth.DomainCategory = "private"
|
||||
case "otherUserId":
|
||||
userAuth.UserId = "otherUserId"
|
||||
userAuth.AccountId = "otherAccountId"
|
||||
userAuth.Domain = "other.com"
|
||||
userAuth.DomainCategory = "private"
|
||||
case "invalidToken":
|
||||
return userAuth, nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
jwtToken := jwt.New(jwt.SigningMethodHS256)
|
||||
return userAuth, jwtToken, nil
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -17,6 +16,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -231,7 +231,7 @@ func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTTo
|
||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
@@ -11,11 +11,12 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockHTTPClient struct {
|
||||
|
||||
@@ -2,7 +2,6 @@ package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -12,6 +11,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/api/v3"
|
||||
|
||||
@@ -166,7 +166,7 @@ func (ac *AuthentikCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -11,6 +10,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -168,7 +168,7 @@ func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTT
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -12,6 +11,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -158,7 +158,7 @@ func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (J
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -13,6 +12,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -253,7 +253,7 @@ func (zc *ZitadelCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JW
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
@@ -46,6 +46,9 @@ func (am *DefaultAccountManager) UpdateIntegratedValidator(ctx context.Context,
|
||||
groups = []string{}
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,14 +3,12 @@ package port_forwarding
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
type Controller interface {
|
||||
SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string, accountPeers map[string]*peer.Peer)
|
||||
GetProxyNetworkMaps(ctx context.Context, accountID, peerID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error)
|
||||
GetProxyNetworkMapsAll(ctx context.Context, accountID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error)
|
||||
SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string)
|
||||
GetProxyNetworkMaps(ctx context.Context, accountID string) (map[string]*nbtypes.NetworkMap, error)
|
||||
IsPeerInIngressPorts(ctx context.Context, accountID, peerID string) (bool, error)
|
||||
}
|
||||
|
||||
@@ -21,15 +19,11 @@ func NewControllerMock() *ControllerMock {
|
||||
return &ControllerMock{}
|
||||
}
|
||||
|
||||
func (c *ControllerMock) SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string, accountPeers map[string]*peer.Peer) {
|
||||
func (c *ControllerMock) SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string) {
|
||||
// noop
|
||||
}
|
||||
|
||||
func (c *ControllerMock) GetProxyNetworkMaps(ctx context.Context, accountID, peerID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error) {
|
||||
return make(map[string]*nbtypes.NetworkMap), nil
|
||||
}
|
||||
|
||||
func (c *ControllerMock) GetProxyNetworkMapsAll(ctx context.Context, accountID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error) {
|
||||
func (c *ControllerMock) GetProxyNetworkMaps(ctx context.Context, accountID string) (map[string]*nbtypes.NetworkMap, error) {
|
||||
return make(map[string]*nbtypes.NetworkMap), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,160 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
reconnThreshold = 5 * time.Minute
|
||||
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
|
||||
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
|
||||
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
|
||||
)
|
||||
|
||||
type lfConfig struct {
|
||||
reconnThreshold time.Duration
|
||||
baseBlockDuration time.Duration
|
||||
reconnLimitForBan int
|
||||
metaChangeLimit int
|
||||
}
|
||||
|
||||
func initCfg() *lfConfig {
|
||||
return &lfConfig{
|
||||
reconnThreshold: reconnThreshold,
|
||||
baseBlockDuration: baseBlockDuration,
|
||||
reconnLimitForBan: reconnLimitForBan,
|
||||
metaChangeLimit: metaChangeLimit,
|
||||
}
|
||||
}
|
||||
|
||||
type loginFilter struct {
|
||||
mu sync.RWMutex
|
||||
cfg *lfConfig
|
||||
logged map[string]*peerState
|
||||
}
|
||||
|
||||
type peerState struct {
|
||||
currentHash uint64
|
||||
sessionCounter int
|
||||
sessionStart time.Time
|
||||
lastSeen time.Time
|
||||
isBanned bool
|
||||
banLevel int
|
||||
banExpiresAt time.Time
|
||||
metaChangeCounter int
|
||||
metaChangeWindowStart time.Time
|
||||
}
|
||||
|
||||
func newLoginFilter() *loginFilter {
|
||||
return newLoginFilterWithCfg(initCfg())
|
||||
}
|
||||
|
||||
func newLoginFilterWithCfg(cfg *lfConfig) *loginFilter {
|
||||
return &loginFilter{
|
||||
logged: make(map[string]*peerState),
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *loginFilter) allowLogin(wgPubKey string, metaHash uint64) bool {
|
||||
l.mu.RLock()
|
||||
defer func() {
|
||||
l.mu.RUnlock()
|
||||
}()
|
||||
state, ok := l.logged[wgPubKey]
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
if state.isBanned && time.Now().Before(state.banExpiresAt) {
|
||||
return false
|
||||
}
|
||||
if metaHash != state.currentHash {
|
||||
if time.Now().Before(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) && state.metaChangeCounter >= l.cfg.metaChangeLimit {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
|
||||
now := time.Now()
|
||||
l.mu.Lock()
|
||||
defer func() {
|
||||
l.mu.Unlock()
|
||||
}()
|
||||
|
||||
state, ok := l.logged[wgPubKey]
|
||||
|
||||
if !ok {
|
||||
l.logged[wgPubKey] = &peerState{
|
||||
currentHash: metaHash,
|
||||
sessionCounter: 1,
|
||||
sessionStart: now,
|
||||
lastSeen: now,
|
||||
metaChangeWindowStart: now,
|
||||
metaChangeCounter: 1,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if state.isBanned && now.After(state.banExpiresAt) {
|
||||
state.isBanned = false
|
||||
}
|
||||
|
||||
if state.banLevel > 0 && now.Sub(state.lastSeen) > (2*l.cfg.baseBlockDuration) {
|
||||
state.banLevel = 0
|
||||
}
|
||||
|
||||
if metaHash != state.currentHash {
|
||||
if now.After(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) {
|
||||
state.metaChangeWindowStart = now
|
||||
state.metaChangeCounter = 1
|
||||
} else {
|
||||
state.metaChangeCounter++
|
||||
}
|
||||
state.currentHash = metaHash
|
||||
state.sessionCounter = 1
|
||||
state.sessionStart = now
|
||||
state.lastSeen = now
|
||||
return
|
||||
}
|
||||
|
||||
state.sessionCounter++
|
||||
if state.sessionCounter > l.cfg.reconnLimitForBan && now.Sub(state.sessionStart) < l.cfg.reconnThreshold {
|
||||
state.isBanned = true
|
||||
state.banLevel++
|
||||
|
||||
backoffFactor := math.Pow(2, float64(state.banLevel-1))
|
||||
duration := time.Duration(float64(l.cfg.baseBlockDuration) * backoffFactor)
|
||||
state.banExpiresAt = now.Add(duration)
|
||||
|
||||
state.sessionCounter = 0
|
||||
state.sessionStart = now
|
||||
}
|
||||
state.lastSeen = now
|
||||
}
|
||||
|
||||
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
|
||||
h := fnv.New64a()
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
h.Write([]byte(meta.OSVersion))
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
h.Write([]byte(meta.SystemSerialNumber))
|
||||
h.Write([]byte(pubip))
|
||||
|
||||
macs := uint64(0)
|
||||
for _, na := range meta.NetworkAddresses {
|
||||
for _, r := range na.Mac {
|
||||
macs += uint64(r)
|
||||
}
|
||||
}
|
||||
|
||||
return h.Sum64() + macs
|
||||
}
|
||||
@@ -1,275 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"math"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
func testAdvancedCfg() *lfConfig {
|
||||
return &lfConfig{
|
||||
reconnThreshold: 50 * time.Millisecond,
|
||||
baseBlockDuration: 100 * time.Millisecond,
|
||||
reconnLimitForBan: 3,
|
||||
metaChangeLimit: 2,
|
||||
}
|
||||
}
|
||||
|
||||
type LoginFilterTestSuite struct {
|
||||
suite.Suite
|
||||
filter *loginFilter
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) SetupTest() {
|
||||
s.filter = newLoginFilterWithCfg(testAdvancedCfg())
|
||||
}
|
||||
|
||||
func TestLoginFilterTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(LoginFilterTestSuite))
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestFirstLoginIsAlwaysAllowed() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta := uint64(1)
|
||||
|
||||
s.True(s.filter.allowLogin(pubKey, meta))
|
||||
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.Equal(1, s.filter.logged[pubKey].sessionCounter)
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestFlappingSameHashTriggersBan() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta := uint64(1)
|
||||
limit := s.filter.cfg.reconnLimitForBan
|
||||
|
||||
for i := 0; i <= limit; i++ {
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
}
|
||||
|
||||
s.False(s.filter.allowLogin(pubKey, meta))
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.True(s.filter.logged[pubKey].isBanned)
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta := uint64(1)
|
||||
limit := s.filter.cfg.reconnLimitForBan
|
||||
baseBan := s.filter.cfg.baseBlockDuration
|
||||
|
||||
for i := 0; i <= limit; i++ {
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
}
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.True(s.filter.logged[pubKey].isBanned)
|
||||
s.Equal(1, s.filter.logged[pubKey].banLevel)
|
||||
firstBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
|
||||
s.InDelta(baseBan, firstBanDuration, float64(time.Millisecond))
|
||||
|
||||
s.filter.logged[pubKey].banExpiresAt = time.Now().Add(-time.Second)
|
||||
s.filter.logged[pubKey].isBanned = false
|
||||
|
||||
for i := 0; i <= limit; i++ {
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
}
|
||||
s.True(s.filter.logged[pubKey].isBanned)
|
||||
s.Equal(2, s.filter.logged[pubKey].banLevel)
|
||||
secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
|
||||
expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1))
|
||||
s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond))
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestPeerIsAllowedAfterBanExpires() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta := uint64(1)
|
||||
|
||||
s.filter.logged[pubKey] = &peerState{
|
||||
isBanned: true,
|
||||
banExpiresAt: time.Now().Add(-(s.filter.cfg.baseBlockDuration + time.Second)),
|
||||
}
|
||||
|
||||
s.True(s.filter.allowLogin(pubKey, meta))
|
||||
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.False(s.filter.logged[pubKey].isBanned)
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestBanLevelResetsAfterGoodBehavior() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta := uint64(1)
|
||||
|
||||
s.filter.logged[pubKey] = &peerState{
|
||||
currentHash: meta,
|
||||
banLevel: 3,
|
||||
lastSeen: time.Now().Add(-3 * s.filter.cfg.baseBlockDuration),
|
||||
}
|
||||
|
||||
s.filter.addLogin(pubKey, meta)
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.Equal(0, s.filter.logged[pubKey].banLevel)
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestFlappingDifferentHashesTriggersBlock() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
limit := s.filter.cfg.metaChangeLimit
|
||||
|
||||
for i := range limit {
|
||||
s.filter.addLogin(pubKey, uint64(i+1))
|
||||
}
|
||||
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.Equal(limit, s.filter.logged[pubKey].metaChangeCounter)
|
||||
|
||||
isAllowed := s.filter.allowLogin(pubKey, uint64(limit+1))
|
||||
|
||||
s.False(isAllowed, "should block new meta hash after limit is reached")
|
||||
}
|
||||
|
||||
func (s *LoginFilterTestSuite) TestMetaChangeIsAllowedAfterWindowResets() {
|
||||
pubKey := "PUB_KEY_A"
|
||||
meta1 := uint64(1)
|
||||
meta2 := uint64(2)
|
||||
meta3 := uint64(3)
|
||||
|
||||
s.filter.addLogin(pubKey, meta1)
|
||||
s.filter.addLogin(pubKey, meta2)
|
||||
s.Require().Contains(s.filter.logged, pubKey)
|
||||
s.Equal(s.filter.cfg.metaChangeLimit, s.filter.logged[pubKey].metaChangeCounter)
|
||||
s.False(s.filter.allowLogin(pubKey, meta3), "should be blocked inside window")
|
||||
|
||||
s.filter.logged[pubKey].metaChangeWindowStart = time.Now().Add(-(s.filter.cfg.reconnThreshold + time.Second))
|
||||
|
||||
s.True(s.filter.allowLogin(pubKey, meta3), "should be allowed after window expires")
|
||||
|
||||
s.filter.addLogin(pubKey, meta3)
|
||||
s.Equal(1, s.filter.logged[pubKey].metaChangeCounter, "meta change counter should reset")
|
||||
}
|
||||
|
||||
func BenchmarkHashingMethods(b *testing.B) {
|
||||
meta := nbpeer.PeerSystemMeta{
|
||||
WtVersion: "1.25.1",
|
||||
OSVersion: "Ubuntu 22.04.3 LTS",
|
||||
KernelVersion: "5.15.0-76-generic",
|
||||
Hostname: "prod-server-database-01",
|
||||
SystemSerialNumber: "PC-1234567890",
|
||||
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
|
||||
}
|
||||
pubip := "8.8.8.8"
|
||||
|
||||
var resultString string
|
||||
var resultUint uint64
|
||||
|
||||
b.Run("BuilderString", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultString = builderString(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("FnvHashToString", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultString = fnvHashToString(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("FnvHashToUint64 - used", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultUint = metaHash(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
_ = resultString
|
||||
_ = resultUint
|
||||
}
|
||||
|
||||
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
h := fnv.New64a()
|
||||
|
||||
if len(meta.NetworkAddresses) != 0 {
|
||||
for _, na := range meta.NetworkAddresses {
|
||||
h.Write([]byte(na.Mac))
|
||||
}
|
||||
}
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
h.Write([]byte(meta.OSVersion))
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
h.Write([]byte(meta.SystemSerialNumber))
|
||||
h.Write([]byte(pubip))
|
||||
|
||||
return strconv.FormatUint(h.Sum64(), 16)
|
||||
}
|
||||
|
||||
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
mac := getMacAddress(meta.NetworkAddresses)
|
||||
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
|
||||
len(pubip) + len(mac) + 6
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(estimatedSize)
|
||||
|
||||
b.WriteString(meta.WtVersion)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.OSVersion)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.KernelVersion)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.Hostname)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.SystemSerialNumber)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(pubip)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func getMacAddress(nas []nbpeer.NetworkAddress) string {
|
||||
if len(nas) == 0 {
|
||||
return ""
|
||||
}
|
||||
macs := make([]string, 0, len(nas))
|
||||
for _, na := range nas {
|
||||
macs = append(macs, na.Mac)
|
||||
}
|
||||
return strings.Join(macs, "/")
|
||||
}
|
||||
|
||||
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
|
||||
filter := newLoginFilterWithCfg(testAdvancedCfg())
|
||||
numKeys := 100000
|
||||
pubKeys := make([]string, numKeys)
|
||||
for i := range numKeys {
|
||||
pubKeys[i] = "PUB_KEY_" + strconv.Itoa(i)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
for pb.Next() {
|
||||
key := pubKeys[r.Intn(numKeys)]
|
||||
meta := r.Uint64()
|
||||
|
||||
if filter.allowLogin(key, meta) {
|
||||
filter.addLogin(key, meta)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -61,7 +61,7 @@ type MockAccountManager struct {
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
UpdatePeerIPFunc func(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool, isSelected bool) (*route.Route, error)
|
||||
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
|
||||
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
||||
@@ -95,8 +95,6 @@ type MockAccountManager struct {
|
||||
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
|
||||
RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error
|
||||
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
||||
HasConnectedChannelFunc func(peerID string) bool
|
||||
GetExternalCacheManagerFunc func() account.ExternalCacheManager
|
||||
@@ -518,9 +516,9 @@ func (am *MockAccountManager) UpdatePeerIP(ctx context.Context, accountID, userI
|
||||
}
|
||||
|
||||
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
|
||||
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool, isSelected bool) (*route.Route, error) {
|
||||
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
||||
if am.CreateRouteFunc != nil {
|
||||
return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute, isSelected)
|
||||
return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
|
||||
}
|
||||
@@ -631,20 +629,6 @@ func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string,
|
||||
return status.Errorf(codes.Unimplemented, "method InviteUser is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) {
|
||||
if am.ApproveUserFunc != nil {
|
||||
return am.ApproveUserFunc(ctx, accountID, initiatorUserID, targetUserID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ApproveUser is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error {
|
||||
if am.RejectUserFunc != nil {
|
||||
return am.RejectUserFunc(ctx, accountID, initiatorUserID, targetUserID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method RejectUser is not implemented")
|
||||
}
|
||||
|
||||
// GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface
|
||||
func (am *MockAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
if am.GetNameServerGroupFunc != nil {
|
||||
@@ -993,10 +977,3 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth n
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) AllowSync(key string, hash uint64) bool {
|
||||
if am.AllowSyncFunc != nil {
|
||||
return am.AllowSyncFunc(key, hash)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -37,6 +37,9 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
||||
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
@@ -70,11 +73,11 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return transaction.SaveNameServerGroup(ctx, newNSGroup)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -91,6 +94,9 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
|
||||
// SaveNameServerGroup saves nameserver group
|
||||
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if nsGroupToSave == nil {
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
|
||||
}
|
||||
@@ -121,11 +127,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return transaction.SaveNameServerGroup(ctx, nsGroupToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -142,6 +148,9 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
|
||||
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
||||
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
@@ -164,11 +173,11 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -70,6 +70,9 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
|
||||
|
||||
network.ID = xid.New().String()
|
||||
|
||||
unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID)
|
||||
defer unlock()
|
||||
|
||||
err = m.store.SaveNetwork(ctx, network)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save network: %w", err)
|
||||
@@ -101,6 +104,9 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID)
|
||||
defer unlock()
|
||||
|
||||
_, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network: %w", err)
|
||||
@@ -125,6 +131,9 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
||||
return fmt.Errorf("failed to get network: %w", err)
|
||||
}
|
||||
|
||||
unlock := m.store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var eventsToStore []func()
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
|
||||
@@ -158,15 +167,15 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
||||
return fmt.Errorf("failed to delete network: %w", err)
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, networkID, accountID, activity.NetworkDeleted, network.EventMeta())
|
||||
})
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, networkID, accountID, activity.NetworkDeleted, network.EventMeta())
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -108,6 +108,9 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
return nil, fmt.Errorf("failed to create new network resource: %w", err)
|
||||
}
|
||||
|
||||
unlock := m.store.AcquireWriteLockByUID(ctx, resource.AccountID)
|
||||
defer unlock()
|
||||
|
||||
var eventsToStore []func()
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
|
||||
@@ -201,6 +204,9 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
resource.Domain = domain
|
||||
resource.Prefix = prefix
|
||||
|
||||
unlock := m.store.AcquireWriteLockByUID(ctx, resource.AccountID)
|
||||
defer unlock()
|
||||
|
||||
var eventsToStore []func()
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
|
||||
@@ -309,6 +315,9 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
unlock := m.store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var events []func()
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID)
|
||||
|
||||
@@ -88,6 +88,9 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
unlock := m.store.AcquireWriteLockByUID(ctx, router.AccountID)
|
||||
defer unlock()
|
||||
|
||||
var network *networkTypes.Network
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
|
||||
@@ -154,6 +157,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
unlock := m.store.AcquireWriteLockByUID(ctx, router.AccountID)
|
||||
defer unlock()
|
||||
|
||||
var network *networkTypes.Network
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
|
||||
@@ -197,6 +203,9 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
unlock := m.store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var event func()
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID)
|
||||
|
||||
@@ -192,6 +192,9 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio
|
||||
|
||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
|
||||
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
@@ -458,6 +461,9 @@ func (am *DefaultAccountManager) GetPeerJobByID(ctx context.Context, accountID,
|
||||
|
||||
// DeletePeer removes peer from the account by its IP
|
||||
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
@@ -480,7 +486,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
var eventsToStore []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -494,6 +500,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
|
||||
return fmt.Errorf("failed to remove peer from groups: %w", err)
|
||||
}
|
||||
|
||||
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete peer: %w", err)
|
||||
@@ -543,7 +553,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
|
||||
}
|
||||
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
|
||||
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, err
|
||||
@@ -615,9 +625,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: user not found")
|
||||
}
|
||||
if user.PendingApproval {
|
||||
return nil, nil, nil, status.Errorf(status.PermissionDenied, "user pending approval cannot add peers")
|
||||
}
|
||||
groupsToAdd = user.AutoGroups
|
||||
opEvent.InitiatorID = userID
|
||||
opEvent.Activity = activity.PeerAddedByUser
|
||||
@@ -728,6 +735,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
newPeer.DNSLabel = freeLabel
|
||||
newPeer.IP = freeIP
|
||||
|
||||
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
err = transaction.AddPeerToAccount(ctx, newPeer)
|
||||
if err != nil {
|
||||
@@ -779,10 +793,14 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
unlock()
|
||||
unlock = nil
|
||||
break
|
||||
}
|
||||
|
||||
if isUniqueConstraintError(err) {
|
||||
unlock()
|
||||
unlock = nil
|
||||
log.WithContext(ctx).WithFields(log.Fields{"dns_label": freeLabel, "ip": freeIP}).Tracef("Failed to add peer in attempt %d, retrying: %v", attempt, err)
|
||||
continue
|
||||
}
|
||||
@@ -941,6 +959,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
}
|
||||
}
|
||||
|
||||
unlockAccount := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer unlockAccount()
|
||||
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, login.WireGuardPubKey)
|
||||
defer func() {
|
||||
if unlockPeer != nil {
|
||||
unlockPeer()
|
||||
}
|
||||
}()
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
var updateRemotePeers bool
|
||||
var isRequiresApproval bool
|
||||
@@ -1021,6 +1048,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
unlockPeer()
|
||||
unlockPeer = nil
|
||||
|
||||
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
@@ -1152,7 +1182,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
|
||||
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return nil, nil, nil, err
|
||||
@@ -1325,7 +1355,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return
|
||||
@@ -1464,7 +1494,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
||||
return
|
||||
}
|
||||
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId, peerId, account.Peers)
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return
|
||||
@@ -1652,7 +1682,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
|
||||
}
|
||||
dnsDomain := am.GetDNSDomain(settings)
|
||||
|
||||
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ type Peer struct {
|
||||
// Meta is a Peer system meta data
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
// Name is peer's name (machine name)
|
||||
Name string `gorm:"index"`
|
||||
Name string
|
||||
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
||||
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||
DNSLabel string // uniqueness index per accountID (check migrations)
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
@@ -990,14 +989,19 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||
b.ReportMetric(msPerOp, "ms/op")
|
||||
|
||||
minExpected := bc.minMsPerOpLocal
|
||||
maxExpected := bc.maxMsPerOpLocal
|
||||
if os.Getenv("CI") == "true" {
|
||||
minExpected = bc.minMsPerOpCICD
|
||||
maxExpected = bc.maxMsPerOpCICD
|
||||
testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer")
|
||||
}
|
||||
|
||||
if msPerOp > maxExpected {
|
||||
b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||
if msPerOp < minExpected {
|
||||
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||
}
|
||||
|
||||
if msPerOp > (maxExpected * 1.1) {
|
||||
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1605,6 +1609,7 @@ func Test_LoginPeer(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupKey string
|
||||
wireGuardPubKey string
|
||||
expectExtraDNSLabelsMismatch bool
|
||||
extraDNSLabels []string
|
||||
expectLoginError bool
|
||||
@@ -1968,7 +1973,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
|
||||
route.Groups, []string{}, true, userID, route.KeepRoute, route.SkipAutoApply,
|
||||
route.Groups, []string{}, true, userID, route.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2383,186 +2388,3 @@ func TestBufferUpdateAccountPeers(t *testing.T) {
|
||||
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
|
||||
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
|
||||
}
|
||||
|
||||
func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create account
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create user pending approval
|
||||
pendingUser := types.NewRegularUser("pending-user")
|
||||
pendingUser.AccountID = account.Id
|
||||
pendingUser.Blocked = true
|
||||
pendingUser.PendingApproval = true
|
||||
err = manager.Store.SaveUser(context.Background(), pendingUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to add peer with pending approval user
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
peer := &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Name: "test-peer",
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-peer",
|
||||
OS: "linux",
|
||||
},
|
||||
}
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", pendingUser.Id, peer)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "user pending approval cannot add peers")
|
||||
}
|
||||
|
||||
func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create account
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create regular user (not pending approval)
|
||||
regularUser := types.NewRegularUser("regular-user")
|
||||
regularUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to add peer with regular user
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
peer := &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Name: "test-peer",
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-peer",
|
||||
OS: "linux",
|
||||
},
|
||||
}
|
||||
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", regularUser.Id, peer)
|
||||
require.NoError(t, err, "Regular user should be able to add peers")
|
||||
}
|
||||
|
||||
func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create account
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create user pending approval
|
||||
pendingUser := types.NewRegularUser("pending-user")
|
||||
pendingUser.AccountID = account.Id
|
||||
pendingUser.Blocked = true
|
||||
pendingUser.PendingApproval = true
|
||||
err = manager.Store.SaveUser(context.Background(), pendingUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a peer using AddPeer method for the pending user (simulate existing peer)
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set the user to not be pending initially so peer can be added
|
||||
pendingUser.Blocked = false
|
||||
pendingUser.PendingApproval = false
|
||||
err = manager.Store.SaveUser(context.Background(), pendingUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add peer using regular flow
|
||||
newPeer := &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Name: "test-peer",
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-peer",
|
||||
OS: "linux",
|
||||
WtVersion: "0.28.0",
|
||||
},
|
||||
}
|
||||
existingPeer, _, _, err := manager.AddPeer(context.Background(), "", pendingUser.Id, newPeer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now set the user back to pending approval after peer was created
|
||||
pendingUser.Blocked = true
|
||||
pendingUser.PendingApproval = true
|
||||
err = manager.Store.SaveUser(context.Background(), pendingUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to login with pending approval user
|
||||
login := types.PeerLogin{
|
||||
WireGuardPubKey: existingPeer.Key,
|
||||
UserID: pendingUser.Id,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-peer",
|
||||
OS: "linux",
|
||||
},
|
||||
}
|
||||
|
||||
_, _, _, err = manager.LoginPeer(context.Background(), login)
|
||||
require.Error(t, err)
|
||||
e, ok := status.FromError(err)
|
||||
require.True(t, ok, "error is not a gRPC status error")
|
||||
assert.Equal(t, status.PermissionDenied, e.Type(), "expected PermissionDenied error code")
|
||||
}
|
||||
|
||||
func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create account
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create regular user (not pending approval)
|
||||
regularUser := types.NewRegularUser("regular-user")
|
||||
regularUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add peer using regular flow for the regular user
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
newPeer := &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Name: "test-peer",
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-peer",
|
||||
OS: "linux",
|
||||
WtVersion: "0.28.0",
|
||||
},
|
||||
}
|
||||
existingPeer, _, _, err := manager.AddPeer(context.Background(), "", regularUser.Id, newPeer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to login with regular user
|
||||
login := types.PeerLogin{
|
||||
WireGuardPubKey: existingPeer.Key,
|
||||
UserID: regularUser.Id,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
Hostname: "test-peer",
|
||||
OS: "linux",
|
||||
},
|
||||
}
|
||||
|
||||
_, _, _, err = manager.LoginPeer(context.Background(), login)
|
||||
require.NoError(t, err, "Regular user should be able to login peers")
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ type Manager interface {
|
||||
GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error)
|
||||
GetPeerAccountID(ctx context.Context, peerID string) (string, error)
|
||||
GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error)
|
||||
GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
@@ -62,7 +61,3 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string)
|
||||
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
|
||||
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
|
||||
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
|
||||
}
|
||||
|
||||
@@ -79,18 +79,3 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
|
||||
}
|
||||
|
||||
// GetPeersByGroupIDs mocks base method.
|
||||
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeersByGroupIDs", ctx, accountID, groupsIDs)
|
||||
ret0, _ := ret[0].([]*peer.Peer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPeersByGroupIDs indicates an expected call of GetPeersByGroupIDs.
|
||||
func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs)
|
||||
}
|
||||
|
||||
@@ -54,14 +54,10 @@ func (m *managerImpl) ValidateUserPermissions(
|
||||
return false, status.NewUserNotFoundError(userID)
|
||||
}
|
||||
|
||||
if user.IsBlocked() && !user.PendingApproval {
|
||||
if user.IsBlocked() {
|
||||
return false, status.NewUserBlockedError()
|
||||
}
|
||||
|
||||
if user.IsBlocked() && user.PendingApproval {
|
||||
return false, status.NewUserPendingApprovalError()
|
||||
}
|
||||
|
||||
if err := m.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -32,6 +32,9 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
||||
|
||||
// SavePolicy in the store
|
||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
operation := operations.Create
|
||||
if !create {
|
||||
operation = operations.Update
|
||||
@@ -58,17 +61,17 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
saveFunc := transaction.CreatePolicy
|
||||
if isUpdate {
|
||||
action = activity.PolicyUpdated
|
||||
saveFunc = transaction.SavePolicy
|
||||
}
|
||||
|
||||
if err = saveFunc(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return saveFunc(ctx, policy)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -85,6 +88,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
|
||||
// DeletePolicy from the store
|
||||
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
@@ -107,11 +113,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.DeletePolicy(ctx, accountID, policyID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return transaction.DeletePolicy(ctx, accountID, policyID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -167,22 +173,10 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
|
||||
// validatePolicy validates the policy and its rules.
|
||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
|
||||
if policy.ID != "" {
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: Refactor to support multiple rules per policy
|
||||
existingRuleIDs := make(map[string]bool)
|
||||
for _, rule := range existingPolicy.Rules {
|
||||
existingRuleIDs[rule.ID] = true
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.ID != "" && !existingRuleIDs[rule.ID] {
|
||||
return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
policy.ID = xid.New().String()
|
||||
policy.AccountID = accountID
|
||||
|
||||
@@ -32,6 +32,9 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
|
||||
|
||||
// SavePostureChecks saves a posture check.
|
||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
operation := operations.Create
|
||||
if !create {
|
||||
operation = operations.Update
|
||||
@@ -59,19 +62,15 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
action = activity.PostureCheckUpdated
|
||||
}
|
||||
|
||||
postureChecks.AccountID = accountID
|
||||
if err = transaction.SavePostureChecks(ctx, postureChecks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isUpdate {
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
return transaction.SavePostureChecks(ctx, postureChecks)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -88,6 +87,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
|
||||
// DeletePostureChecks deletes a posture check by ID.
|
||||
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
@@ -108,11 +110,11 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.DeletePostureChecks(ctx, accountID, postureChecksID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return transaction.DeletePostureChecks(ctx, accountID, postureChecksID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -134,7 +134,10 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string {
|
||||
}
|
||||
|
||||
// CreateRoute creates and saves a new route
|
||||
func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool, skipAutoApply bool) (*route.Route, error) {
|
||||
func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
@@ -167,7 +170,6 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
Enabled: enabled,
|
||||
Groups: groups,
|
||||
AccessControlGroups: accessControlGroupIDs,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
}
|
||||
|
||||
if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil {
|
||||
@@ -179,11 +181,11 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SaveRoute(ctx, newRoute); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return transaction.SaveRoute(ctx, newRoute)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -200,6 +202,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
|
||||
// SaveRoute saves route
|
||||
func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
@@ -233,11 +238,11 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
}
|
||||
routeToSave.AccountID = accountID
|
||||
|
||||
if err = transaction.SaveRoute(ctx, routeToSave); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return transaction.SaveRoute(ctx, routeToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -254,6 +259,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
|
||||
// DeleteRoute deletes route with routeID
|
||||
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
@@ -276,11 +284,11 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.DeleteRoute(ctx, accountID, string(routeID)); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
return transaction.DeleteRoute(ctx, accountID, string(routeID))
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete route %s: %w", routeID, err)
|
||||
@@ -374,16 +382,15 @@ func validateRouteGroups(ctx context.Context, transaction store.Store, accountID
|
||||
|
||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||
return &proto.Route{
|
||||
ID: string(route.ID),
|
||||
NetID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Domains: route.Domains.ToPunycodeList(),
|
||||
NetworkType: int64(route.NetworkType),
|
||||
Peer: route.Peer,
|
||||
Metric: int64(route.Metric),
|
||||
Masquerade: route.Masquerade,
|
||||
KeepRoute: route.KeepRoute,
|
||||
SkipAutoApply: route.SkipAutoApply,
|
||||
ID: string(route.ID),
|
||||
NetID: string(route.NetID),
|
||||
Network: route.Network.String(),
|
||||
Domains: route.Domains.ToPunycodeList(),
|
||||
NetworkType: int64(route.NetworkType),
|
||||
Peer: route.Peer,
|
||||
Metric: int64(route.Metric),
|
||||
Masquerade: route.Masquerade,
|
||||
KeepRoute: route.KeepRoute,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -69,7 +69,6 @@ func TestCreateRoute(t *testing.T) {
|
||||
enabled bool
|
||||
groups []string
|
||||
accessControlGroups []string
|
||||
skipAutoApply bool
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
@@ -445,13 +444,13 @@ func TestCreateRoute(t *testing.T) {
|
||||
if testCase.createInitRoute {
|
||||
groupAll, errInit := account.GetGroupAll()
|
||||
require.NoError(t, errInit)
|
||||
_, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false, true)
|
||||
_, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
|
||||
require.NoError(t, errInit)
|
||||
_, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false, true)
|
||||
_, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false)
|
||||
require.NoError(t, errInit)
|
||||
}
|
||||
|
||||
outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute, testCase.inputArgs.skipAutoApply)
|
||||
outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
|
||||
|
||||
testCase.errFunc(t, err)
|
||||
|
||||
@@ -1085,7 +1084,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
|
||||
|
||||
newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute, baseRoute.SkipAutoApply)
|
||||
newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, newRoute.Enabled, true)
|
||||
|
||||
@@ -1177,7 +1176,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
|
||||
|
||||
createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute, baseRoute.SkipAutoApply)
|
||||
createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute)
|
||||
require.NoError(t, err)
|
||||
|
||||
noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
|
||||
@@ -2005,7 +2004,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
|
||||
route.Groups, []string{}, true, userID, route.KeepRoute, route.SkipAutoApply,
|
||||
route.Groups, []string{}, true, userID, route.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2041,7 +2040,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
|
||||
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
|
||||
route.Groups, []string{}, true, userID, route.KeepRoute, route.SkipAutoApply,
|
||||
route.Groups, []string{}, true, userID, route.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2077,7 +2076,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
newRoute, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer,
|
||||
baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric,
|
||||
baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute, !baseRoute.SkipAutoApply,
|
||||
baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
baseRoute = *newRoute
|
||||
@@ -2143,7 +2142,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
|
||||
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, !newRoute.SkipAutoApply,
|
||||
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -2183,7 +2182,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
_, err := manager.CreateRoute(
|
||||
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
|
||||
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
|
||||
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, !newRoute.SkipAutoApply,
|
||||
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@@ -55,6 +55,8 @@ type SetupKeyUpdateOperation struct {
|
||||
// and adds it to the specified account. A list of autoGroups IDs can be empty.
|
||||
func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Create)
|
||||
if err != nil {
|
||||
@@ -105,6 +107,9 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
|
||||
@@ -52,6 +52,7 @@ const (
|
||||
// SqlStore represents an account storage backed by a Sql DB persisted to disk
|
||||
type SqlStore struct {
|
||||
db *gorm.DB
|
||||
resourceLocks sync.Map
|
||||
globalAccountLock sync.Mutex
|
||||
metrics telemetry.AppMetrics
|
||||
installationPK int
|
||||
@@ -218,6 +219,44 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
|
||||
return unlock
|
||||
}
|
||||
|
||||
// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
|
||||
func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID)
|
||||
|
||||
startWait := time.Now()
|
||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.Lock()
|
||||
log.WithContext(ctx).Tracef("waiting to acquire write lock for ID %s in %v", uniqueID, time.Since(startWait))
|
||||
startHold := time.Now()
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(startHold))
|
||||
}
|
||||
|
||||
return unlock
|
||||
}
|
||||
|
||||
// AcquireReadLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
|
||||
func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID)
|
||||
|
||||
startWait := time.Now()
|
||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.RLock()
|
||||
log.WithContext(ctx).Tracef("waiting to acquire read lock for ID %s in %v", uniqueID, time.Since(startWait))
|
||||
startHold := time.Now()
|
||||
|
||||
unlock = func() {
|
||||
mtx.RUnlock()
|
||||
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(startHold))
|
||||
}
|
||||
|
||||
return unlock
|
||||
}
|
||||
|
||||
// Deprecated: Full account operations are no longer supported
|
||||
func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error {
|
||||
start := time.Now()
|
||||
@@ -989,7 +1028,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
|
||||
|
||||
func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) {
|
||||
var account types.Account
|
||||
result := s.db.Select("id").Order("created_at desc").Limit(1).Find(&account)
|
||||
result := s.db.WithContext(ctx).Select("id").Order("created_at desc").Limit(1).Find(&account)
|
||||
if result.Error != nil {
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
@@ -1474,7 +1513,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI
|
||||
PeerID: peerID,
|
||||
}
|
||||
|
||||
err := s.db.Clauses(clause.OnConflict{
|
||||
err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
|
||||
DoNothing: true,
|
||||
}).Create(peer).Error
|
||||
@@ -1489,7 +1528,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI
|
||||
|
||||
// RemovePeerFromGroup removes a peer from a group
|
||||
func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error {
|
||||
err := s.db.
|
||||
err := s.db.WithContext(ctx).
|
||||
Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error
|
||||
|
||||
if err != nil {
|
||||
@@ -1502,7 +1541,7 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group
|
||||
|
||||
// RemovePeerFromAllGroups removes a peer from all groups
|
||||
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
|
||||
err := s.db.
|
||||
err := s.db.WithContext(ctx).
|
||||
Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error
|
||||
|
||||
if err != nil {
|
||||
@@ -2090,7 +2129,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error {
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) error {
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil {
|
||||
return fmt.Errorf("delete policy rules: %w", err)
|
||||
}
|
||||
@@ -2781,7 +2820,7 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) {
|
||||
tx := s.db
|
||||
tx := s.db.WithContext(ctx)
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
@@ -2922,22 +2961,3 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return []*nbpeer.Peer{}, nil
|
||||
}
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
peerIDsSubquery := s.db.Model(&types.GroupPeer{}).
|
||||
Select("DISTINCT peer_id").
|
||||
Where("account_id = ? AND group_id IN ?", accountID, groupIDs)
|
||||
|
||||
result := s.db.Where("id IN (?)", peerIDsSubquery).Find(&peers)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers by group IDs: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get peers by group IDs")
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
@@ -3607,113 +3607,3 @@ func intToIPv4(n uint32) net.IP {
|
||||
binary.BigEndian.PutUint32(ip, n)
|
||||
return ip
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPeersByGroupIDs(t *testing.T) {
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
group1ID := "test-group-1"
|
||||
group2ID := "test-group-2"
|
||||
emptyGroupID := "empty-group"
|
||||
|
||||
peer1 := "cfefqs706sqkneg59g4g"
|
||||
peer2 := "cfeg6sf06sqkneg59g50"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
groupIDs []string
|
||||
expectedPeers []string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "retrieve peers from single group with multiple peers",
|
||||
groupIDs: []string{group1ID},
|
||||
expectedPeers: []string{peer1, peer2},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "retrieve peers from single group with one peer",
|
||||
groupIDs: []string{group2ID},
|
||||
expectedPeers: []string{peer1},
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "retrieve peers from multiple groups (with overlap)",
|
||||
groupIDs: []string{group1ID, group2ID},
|
||||
expectedPeers: []string{peer1, peer2}, // should deduplicate
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "retrieve peers from existing 'All' group",
|
||||
groupIDs: []string{"cfefqs706sqkneg59g3g"}, // All group from test data
|
||||
expectedPeers: []string{peer1, peer2},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "retrieve peers from empty group",
|
||||
groupIDs: []string{emptyGroupID},
|
||||
expectedPeers: []string{},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "retrieve peers from non-existing group",
|
||||
groupIDs: []string{"non-existing-group"},
|
||||
expectedPeers: []string{},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "empty group IDs list",
|
||||
groupIDs: []string{},
|
||||
expectedPeers: []string{},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "mix of existing and non-existing groups",
|
||||
groupIDs: []string{group1ID, "non-existing-group"},
|
||||
expectedPeers: []string{peer1, peer2},
|
||||
expectedCount: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
groups := []*types.Group{
|
||||
{
|
||||
ID: group1ID,
|
||||
AccountID: accountID,
|
||||
},
|
||||
{
|
||||
ID: group2ID,
|
||||
AccountID: accountID,
|
||||
},
|
||||
}
|
||||
require.NoError(t, store.CreateGroups(ctx, accountID, groups))
|
||||
|
||||
require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group1ID))
|
||||
require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer2, group1ID))
|
||||
require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group2ID))
|
||||
|
||||
peers, err := store.GetPeersByGroupIDs(ctx, accountID, tt.groupIDs)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, tt.expectedCount)
|
||||
|
||||
if tt.expectedCount > 0 {
|
||||
actualPeerIDs := make([]string, len(peers))
|
||||
for i, peer := range peers {
|
||||
actualPeerIDs[i] = peer.ID
|
||||
}
|
||||
assert.ElementsMatch(t, tt.expectedPeers, actualPeerIDs)
|
||||
|
||||
// Verify all returned peers belong to the correct account
|
||||
for _, peer := range peers {
|
||||
assert.Equal(t, accountID, peer.AccountID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,7 +136,6 @@ type Store interface {
|
||||
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
|
||||
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
|
||||
GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error)
|
||||
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
|
||||
@@ -169,6 +168,10 @@ type Store interface {
|
||||
GetInstallationID() string
|
||||
SaveInstallationID(ctx context.Context, ID string) error
|
||||
|
||||
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
|
||||
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
|
||||
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
|
||||
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
|
||||
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
||||
AcquireGlobalLock(ctx context.Context) func()
|
||||
|
||||
|
||||
@@ -4,28 +4,20 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
const AccountIDLabel = "account_id"
|
||||
const HighLatencyThreshold = time.Second * 7
|
||||
|
||||
// GRPCMetrics are gRPC server metrics
|
||||
type GRPCMetrics struct {
|
||||
meter metric.Meter
|
||||
syncRequestsCounter metric.Int64Counter
|
||||
syncRequestsBlockedCounter metric.Int64Counter
|
||||
syncRequestHighLatencyCounter metric.Int64Counter
|
||||
loginRequestsCounter metric.Int64Counter
|
||||
loginRequestsBlockedCounter metric.Int64Counter
|
||||
loginRequestHighLatencyCounter metric.Int64Counter
|
||||
getKeyRequestsCounter metric.Int64Counter
|
||||
activeStreamsGauge metric.Int64ObservableGauge
|
||||
syncRequestDuration metric.Int64Histogram
|
||||
loginRequestDuration metric.Int64Histogram
|
||||
channelQueueLength metric.Int64Histogram
|
||||
ctx context.Context
|
||||
meter metric.Meter
|
||||
syncRequestsCounter metric.Int64Counter
|
||||
loginRequestsCounter metric.Int64Counter
|
||||
getKeyRequestsCounter metric.Int64Counter
|
||||
activeStreamsGauge metric.Int64ObservableGauge
|
||||
syncRequestDuration metric.Int64Histogram
|
||||
loginRequestDuration metric.Int64Histogram
|
||||
channelQueueLength metric.Int64Histogram
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server
|
||||
@@ -38,22 +30,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
syncRequestsBlockedCounter, err := meter.Int64Counter("management.grpc.sync.request.blocked.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of sync gRPC requests from blocked peers"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
|
||||
@@ -62,22 +38,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginRequestsBlockedCounter, err := meter.Int64Counter("management.grpc.login.request.blocked.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of login gRPC requests from blocked peers"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.login.request.high.latency.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of login gRPC requests from the peers that took longer than the threshold to authenticate and receive initial configuration and relay credentials"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getKeyRequestsCounter, err := meter.Int64Counter("management.grpc.key.request.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of key gRPC requests from the peers to get the server's public WireGuard key"),
|
||||
@@ -123,19 +83,15 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
|
||||
}
|
||||
|
||||
return &GRPCMetrics{
|
||||
meter: meter,
|
||||
syncRequestsCounter: syncRequestsCounter,
|
||||
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
|
||||
syncRequestHighLatencyCounter: syncRequestHighLatencyCounter,
|
||||
loginRequestsCounter: loginRequestsCounter,
|
||||
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
|
||||
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
|
||||
getKeyRequestsCounter: getKeyRequestsCounter,
|
||||
activeStreamsGauge: activeStreamsGauge,
|
||||
syncRequestDuration: syncRequestDuration,
|
||||
loginRequestDuration: loginRequestDuration,
|
||||
channelQueueLength: channelQueue,
|
||||
ctx: ctx,
|
||||
meter: meter,
|
||||
syncRequestsCounter: syncRequestsCounter,
|
||||
loginRequestsCounter: loginRequestsCounter,
|
||||
getKeyRequestsCounter: getKeyRequestsCounter,
|
||||
activeStreamsGauge: activeStreamsGauge,
|
||||
syncRequestDuration: syncRequestDuration,
|
||||
loginRequestDuration: loginRequestDuration,
|
||||
channelQueueLength: channelQueue,
|
||||
ctx: ctx,
|
||||
}, err
|
||||
}
|
||||
|
||||
@@ -144,11 +100,6 @@ func (grpcMetrics *GRPCMetrics) CountSyncRequest() {
|
||||
grpcMetrics.syncRequestsCounter.Add(grpcMetrics.ctx, 1)
|
||||
}
|
||||
|
||||
// CountSyncRequestBlocked counts the number of gRPC sync requests from blocked peers
|
||||
func (grpcMetrics *GRPCMetrics) CountSyncRequestBlocked() {
|
||||
grpcMetrics.syncRequestsBlockedCounter.Add(grpcMetrics.ctx, 1)
|
||||
}
|
||||
|
||||
// CountGetKeyRequest counts the number of gRPC get server key requests coming to the gRPC API
|
||||
func (grpcMetrics *GRPCMetrics) CountGetKeyRequest() {
|
||||
grpcMetrics.getKeyRequestsCounter.Add(grpcMetrics.ctx, 1)
|
||||
@@ -159,25 +110,14 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequest() {
|
||||
grpcMetrics.loginRequestsCounter.Add(grpcMetrics.ctx, 1)
|
||||
}
|
||||
|
||||
// CountLoginRequestBlocked counts the number of gRPC login requests from blocked peers
|
||||
func (grpcMetrics *GRPCMetrics) CountLoginRequestBlocked() {
|
||||
grpcMetrics.loginRequestsBlockedCounter.Add(grpcMetrics.ctx, 1)
|
||||
}
|
||||
|
||||
// CountLoginRequestDuration counts the duration of the login gRPC requests
|
||||
func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) {
|
||||
func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration) {
|
||||
grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
|
||||
if duration > HighLatencyThreshold {
|
||||
grpcMetrics.loginRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
|
||||
}
|
||||
}
|
||||
|
||||
// CountSyncRequestDuration counts the duration of the sync gRPC requests
|
||||
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
|
||||
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration) {
|
||||
grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
|
||||
if duration > HighLatencyThreshold {
|
||||
grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.
|
||||
|
||||
@@ -300,12 +300,9 @@ func (a *Account) GetPeerNetworkMap(
|
||||
|
||||
if dnsManagementStatus {
|
||||
var zones []nbdns.CustomZone
|
||||
|
||||
if peersCustomZone.Domain != "" {
|
||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect)
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: peersCustomZone.Domain,
|
||||
Records: records,
|
||||
})
|
||||
zones = append(zones, peersCustomZone)
|
||||
}
|
||||
dnsUpdate.CustomZones = zones
|
||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||
@@ -1654,24 +1651,3 @@ func peerSupportsPortRanges(peerVer string) bool {
|
||||
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
|
||||
return err == nil && meetMinVer
|
||||
}
|
||||
|
||||
// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
|
||||
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord {
|
||||
filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
|
||||
peerIPs := make(map[string]struct{})
|
||||
|
||||
// Add peer's own IP to include its own DNS records
|
||||
peerIPs[peer.IP.String()] = struct{}{}
|
||||
|
||||
for _, peerToConnect := range peersToConnect {
|
||||
peerIPs[peerToConnect.IP.String()] = struct{}{}
|
||||
}
|
||||
|
||||
for _, record := range customZone.Records {
|
||||
if _, exists := peerIPs[record.RData]; exists {
|
||||
filteredRecords = append(filteredRecords, record)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredRecords
|
||||
}
|
||||
|
||||
@@ -2,17 +2,14 @@ package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -838,109 +835,3 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) {
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
|
||||
}
|
||||
|
||||
func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
peer *nbpeer.Peer
|
||||
customZone nbdns.CustomZone
|
||||
peersToConnect []*nbpeer.Peer
|
||||
expectedRecords []nbdns.SimpleRecord
|
||||
}{
|
||||
{
|
||||
name: "empty peers to connect",
|
||||
customZone: nbdns.CustomZone{
|
||||
Domain: "netbird.cloud.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||
},
|
||||
},
|
||||
peersToConnect: []*nbpeer.Peer{},
|
||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||
expectedRecords: []nbdns.SimpleRecord{
|
||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple peers multiple records match",
|
||||
customZone: nbdns.CustomZone{
|
||||
Domain: "netbird.cloud.",
|
||||
Records: func() []nbdns.SimpleRecord {
|
||||
var records []nbdns.SimpleRecord
|
||||
for i := 1; i <= 100; i++ {
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: fmt.Sprintf("peer%d.netbird.cloud", i),
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256),
|
||||
})
|
||||
}
|
||||
return records
|
||||
}(),
|
||||
},
|
||||
peersToConnect: func() []*nbpeer.Peer {
|
||||
var peers []*nbpeer.Peer
|
||||
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
|
||||
peers = append(peers, &nbpeer.Peer{
|
||||
ID: fmt.Sprintf("peer%d", i),
|
||||
IP: net.ParseIP(fmt.Sprintf("10.0.%d.%d", i/256, i%256)),
|
||||
})
|
||||
}
|
||||
return peers
|
||||
}(),
|
||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||
expectedRecords: func() []nbdns.SimpleRecord {
|
||||
var records []nbdns.SimpleRecord
|
||||
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
|
||||
records = append(records, nbdns.SimpleRecord{
|
||||
Name: fmt.Sprintf("peer%d.netbird.cloud", i),
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 300,
|
||||
RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256),
|
||||
})
|
||||
}
|
||||
return records
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "peers with multiple DNS labels",
|
||||
customZone: nbdns.CustomZone{
|
||||
Domain: "netbird.cloud.",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
|
||||
{Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
|
||||
{Name: "peer3.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"},
|
||||
{Name: "peer3-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"},
|
||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||
},
|
||||
},
|
||||
peersToConnect: []*nbpeer.Peer{
|
||||
{ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}},
|
||||
{ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}},
|
||||
},
|
||||
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
|
||||
expectedRecords: []nbdns.SimpleRecord{
|
||||
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||
{Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
|
||||
{Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
|
||||
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect)
|
||||
assert.Equal(t, len(tt.expectedRecords), len(result))
|
||||
assert.ElementsMatch(t, tt.expectedRecords, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,11 +12,11 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -83,9 +83,6 @@ type ExtraSettings struct {
|
||||
// PeerApprovalEnabled enables or disables the need for peers bo be approved by an administrator
|
||||
PeerApprovalEnabled bool
|
||||
|
||||
// UserApprovalRequired enables or disables the need for users joining via domain matching to be approved by an administrator
|
||||
UserApprovalRequired bool
|
||||
|
||||
// IntegratedValidator is the string enum for the integrated validator type
|
||||
IntegratedValidator string
|
||||
// IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations
|
||||
@@ -102,7 +99,6 @@ type ExtraSettings struct {
|
||||
func (e *ExtraSettings) Copy() *ExtraSettings {
|
||||
return &ExtraSettings{
|
||||
PeerApprovalEnabled: e.PeerApprovalEnabled,
|
||||
UserApprovalRequired: e.UserApprovalRequired,
|
||||
IntegratedValidatorGroups: slices.Clone(e.IntegratedValidatorGroups),
|
||||
IntegratedValidator: e.IntegratedValidator,
|
||||
FlowEnabled: e.FlowEnabled,
|
||||
|
||||
@@ -64,7 +64,6 @@ type UserInfo struct {
|
||||
NonDeletable bool `json:"non_deletable"`
|
||||
LastLogin time.Time `json:"last_login"`
|
||||
Issued string `json:"issued"`
|
||||
PendingApproval bool `json:"pending_approval"`
|
||||
IntegrationReference integration_reference.IntegrationReference `json:"-"`
|
||||
}
|
||||
|
||||
@@ -85,8 +84,6 @@ type User struct {
|
||||
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
||||
Blocked bool
|
||||
// PendingApproval indicates whether the user requires approval before being activated
|
||||
PendingApproval bool
|
||||
// LastLogin is the last time the user logged in to IdP
|
||||
LastLogin *time.Time
|
||||
// CreatedAt records the time the user was created
|
||||
@@ -144,17 +141,16 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
|
||||
if userData == nil {
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
Email: "",
|
||||
Name: u.ServiceUserName,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: u.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
IsBlocked: u.Blocked,
|
||||
LastLogin: u.GetLastLogin(),
|
||||
Issued: u.Issued,
|
||||
PendingApproval: u.PendingApproval,
|
||||
ID: u.Id,
|
||||
Email: "",
|
||||
Name: u.ServiceUserName,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: u.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
IsBlocked: u.Blocked,
|
||||
LastLogin: u.GetLastLogin(),
|
||||
Issued: u.Issued,
|
||||
}, nil
|
||||
}
|
||||
if userData.ID != u.Id {
|
||||
@@ -167,17 +163,16 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
}
|
||||
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
Email: userData.Email,
|
||||
Name: userData.Name,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: autoGroups,
|
||||
Status: string(userStatus),
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
IsBlocked: u.Blocked,
|
||||
LastLogin: u.GetLastLogin(),
|
||||
Issued: u.Issued,
|
||||
PendingApproval: u.PendingApproval,
|
||||
ID: u.Id,
|
||||
Email: userData.Email,
|
||||
Name: userData.Name,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: autoGroups,
|
||||
Status: string(userStatus),
|
||||
IsServiceUser: u.IsServiceUser,
|
||||
IsBlocked: u.Blocked,
|
||||
LastLogin: u.GetLastLogin(),
|
||||
Issued: u.Issued,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -199,7 +194,6 @@ func (u *User) Copy() *User {
|
||||
ServiceUserName: u.ServiceUserName,
|
||||
PATs: pats,
|
||||
Blocked: u.Blocked,
|
||||
PendingApproval: u.PendingApproval,
|
||||
LastLogin: u.LastLogin,
|
||||
CreatedAt: u.CreatedAt,
|
||||
Issued: u.Issued,
|
||||
|
||||
@@ -26,6 +26,9 @@ import (
|
||||
|
||||
// createServiceUser creates a new service user under the given account.
|
||||
func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role types.UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*types.UserInfo, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
@@ -73,6 +76,9 @@ func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, user
|
||||
|
||||
// inviteNewUser Invites a USer to a given account and creates reference in datastore
|
||||
func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *types.UserInfo) (*types.UserInfo, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if am.idpManager == nil {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites")
|
||||
}
|
||||
@@ -221,6 +227,9 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
|
||||
return status.Errorf(status.InvalidArgument, "self deletion is not allowed")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -276,6 +285,9 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
|
||||
|
||||
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
|
||||
func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if am.idpManager == nil {
|
||||
return status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites")
|
||||
}
|
||||
@@ -316,6 +328,9 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
|
||||
|
||||
// CreatePAT creates a new PAT for the given user
|
||||
func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if tokenName == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "token name can't be empty")
|
||||
}
|
||||
@@ -364,6 +379,9 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
|
||||
|
||||
// DeletePAT deletes a specific PAT from a user
|
||||
func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
@@ -463,6 +481,9 @@ func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initia
|
||||
// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
|
||||
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
|
||||
func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*types.User{update}, addIfNotExists)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -519,46 +540,33 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
initiatorUser = result
|
||||
}
|
||||
|
||||
var globalErr error
|
||||
for _, update := range updates {
|
||||
if update == nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||
}
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
for _, update := range updates {
|
||||
if update == nil {
|
||||
return status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
|
||||
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process update for user %s: %w", update.Id, err)
|
||||
}
|
||||
|
||||
if userHadPeers {
|
||||
updateAccountPeers = true
|
||||
}
|
||||
|
||||
err = transaction.SaveUser(ctx, updatedUser)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save updated user %s: %w", update.Id, err)
|
||||
}
|
||||
|
||||
usersToSave = append(usersToSave, updatedUser)
|
||||
addUserEvents = append(addUserEvents, userEvents...)
|
||||
peersToExpire = append(peersToExpire, userPeersToExpire...)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save user %s: %s", update.Id, err)
|
||||
if len(updates) == 1 {
|
||||
return nil, err
|
||||
if userHadPeers {
|
||||
updateAccountPeers = true
|
||||
}
|
||||
globalErr = errors.Join(globalErr, err)
|
||||
// continue when updating multiple users
|
||||
}
|
||||
return transaction.SaveUsers(ctx, usersToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var updatedUsersInfo = make([]*types.UserInfo, 0, len(usersToSave))
|
||||
var updatedUsersInfo = make([]*types.UserInfo, 0, len(updates))
|
||||
|
||||
userInfos, err := am.GetUsersFromAccount(ctx, accountID, initiatorUserID)
|
||||
if err != nil {
|
||||
@@ -591,7 +599,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return updatedUsersInfo, globalErr
|
||||
return updatedUsersInfo, nil
|
||||
}
|
||||
|
||||
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
|
||||
@@ -656,7 +664,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
}
|
||||
transferredOwnerRole = result
|
||||
|
||||
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, updatedUser.AccountID, update.Id)
|
||||
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthUpdate, updatedUser.AccountID, update.Id)
|
||||
if err != nil {
|
||||
return false, nil, nil, nil, err
|
||||
}
|
||||
@@ -942,11 +950,6 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key)
|
||||
|
||||
if peer.UserID == "" {
|
||||
// we do not want to expire peers that are added via setup key
|
||||
continue
|
||||
}
|
||||
|
||||
if peer.Status.LoginExpired {
|
||||
continue
|
||||
}
|
||||
@@ -965,7 +968,6 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
||||
|
||||
if len(peerIDs) != 0 {
|
||||
// this will trigger peer disconnect from the management service
|
||||
log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID)
|
||||
am.peersUpdateManager.CloseChannels(ctx, peerIDs)
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
@@ -1213,77 +1215,3 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut
|
||||
|
||||
return userWithPermissions, nil
|
||||
}
|
||||
|
||||
// ApproveUser approves a user that is pending approval
|
||||
func (am *DefaultAccountManager) ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.NewUserNotFoundError(targetUserID)
|
||||
}
|
||||
|
||||
if !user.PendingApproval {
|
||||
return nil, status.Errorf(status.InvalidArgument, "user %s is not pending approval", targetUserID)
|
||||
}
|
||||
|
||||
user.Blocked = false
|
||||
user.PendingApproval = false
|
||||
|
||||
err = am.Store.SaveUser(ctx, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserApproved, nil)
|
||||
|
||||
userInfo, err := am.getUserInfo(ctx, user, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
// RejectUser rejects a user that is pending approval by deleting them
|
||||
func (am *DefaultAccountManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotFoundError(targetUserID)
|
||||
}
|
||||
|
||||
if !user.PendingApproval {
|
||||
return status.Errorf(status.InvalidArgument, "user %s is not pending approval", targetUserID)
|
||||
}
|
||||
|
||||
err = am.DeleteUser(ctx, accountID, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserRejected, nil)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1746,117 +1746,3 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
|
||||
|
||||
return permissions
|
||||
}
|
||||
|
||||
func TestApproveUser(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create account with admin and pending approval user
|
||||
account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create admin user
|
||||
adminUser := types.NewAdminUser("admin-user")
|
||||
adminUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), adminUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create user pending approval
|
||||
pendingUser := types.NewRegularUser("pending-user")
|
||||
pendingUser.AccountID = account.Id
|
||||
pendingUser.Blocked = true
|
||||
pendingUser.PendingApproval = true
|
||||
err = manager.Store.SaveUser(context.Background(), pendingUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test successful approval
|
||||
approvedUser, err := manager.ApproveUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, approvedUser.IsBlocked)
|
||||
assert.False(t, approvedUser.PendingApproval)
|
||||
|
||||
// Verify user is updated in store
|
||||
updatedUser, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, pendingUser.Id)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, updatedUser.Blocked)
|
||||
assert.False(t, updatedUser.PendingApproval)
|
||||
|
||||
// Test approval of non-pending user should fail
|
||||
_, err = manager.ApproveUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not pending approval")
|
||||
|
||||
// Test approval by non-admin should fail
|
||||
regularUser := types.NewRegularUser("regular-user")
|
||||
regularUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
pendingUser2 := types.NewRegularUser("pending-user-2")
|
||||
pendingUser2.AccountID = account.Id
|
||||
pendingUser2.Blocked = true
|
||||
pendingUser2.PendingApproval = true
|
||||
err = manager.Store.SaveUser(context.Background(), pendingUser2)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = manager.ApproveUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRejectUser(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create account with admin and pending approval user
|
||||
account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create admin user
|
||||
adminUser := types.NewAdminUser("admin-user")
|
||||
adminUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), adminUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create user pending approval
|
||||
pendingUser := types.NewRegularUser("pending-user")
|
||||
pendingUser.AccountID = account.Id
|
||||
pendingUser.Blocked = true
|
||||
pendingUser.PendingApproval = true
|
||||
err = manager.Store.SaveUser(context.Background(), pendingUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test successful rejection
|
||||
err = manager.RejectUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify user is deleted from store
|
||||
_, err = manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, pendingUser.Id)
|
||||
require.Error(t, err)
|
||||
|
||||
// Test rejection of non-pending user should fail
|
||||
regularUser := types.NewRegularUser("regular-user")
|
||||
regularUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = manager.RejectUser(context.Background(), account.Id, adminUser.Id, regularUser.Id)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not pending approval")
|
||||
|
||||
// Test rejection by non-admin should fail
|
||||
pendingUser2 := types.NewRegularUser("pending-user-2")
|
||||
pendingUser2.AccountID = account.Id
|
||||
pendingUser2.Blocked = true
|
||||
pendingUser2.PendingApproval = true
|
||||
err = manager.Store.SaveUser(context.Background(), pendingUser2)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = manager.RejectUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user