Revert "Merge branch 'main' into feature/remote-debug"

This reverts commit 6d6333058c, reversing
changes made to 446aded1f7.
This commit is contained in:
aliamerj
2025-10-06 12:24:48 +03:00
parent 6d6333058c
commit ba7793ae7b
288 changed files with 3117 additions and 8952 deletions

View File

@@ -111,6 +111,3 @@ Generate gRpc code:
#!/bin/bash
protoc -I proto/ proto/management.proto --go_out=. --go-grpc_out=.
```

View File

@@ -26,11 +26,7 @@ func (s *BaseServer) JobManager() *server.JobManager {
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
return Create(s, func() integrated_validator.IntegratedValidator {
integratedPeerValidator, err := integrations.NewIntegratedValidator(
context.Background(),
s.PeersManager(),
s.SettingsManager(),
s.EventStore())
integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.EventStore())
if err != nil {
log.Errorf("failed to create integrated peer validator: %v", err)
}

View File

@@ -105,8 +105,6 @@ type DefaultAccountManager struct {
accountUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64
loginFilter *loginFilter
disableDefaultPolicy bool
}
@@ -216,7 +214,6 @@ func BuildManager(
proxyController: proxyController,
settingsManager: settingsManager,
permissionsManager: permissionsManager,
loginFilter: newLoginFilter(),
disableDefaultPolicy: disableDefaultPolicy,
}
@@ -303,6 +300,9 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
// User that performs the update has to belong to the account.
// Returns an updated Settings
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
@@ -348,17 +348,13 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
}
}
if err = transaction.SaveAccountSettings(ctx, accountID, newSettings); err != nil {
return err
}
if updateAccountPeers || groupsUpdated {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
}
return nil
return transaction.SaveAccountSettings(ctx, accountID, newSettings)
})
if err != nil {
return nil, err
@@ -502,6 +498,8 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
ctx := context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
//nolint
ctx = context.WithValue(ctx, hook.ExecutionContextKey, fmt.Sprintf("%s-PEER-EXPIRATION", hook.SystemSource))
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
expiredPeers, err := am.getExpiredPeers(ctx, accountID)
if err != nil {
@@ -537,6 +535,9 @@ func (am *DefaultAccountManager) schedulePeerLoginExpiration(ctx context.Context
// peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found
func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
inactivePeers, err := am.getInactivePeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID)
@@ -677,6 +678,8 @@ func (am *DefaultAccountManager) isCacheCold(ctx context.Context, store cacheSto
// DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner
func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
@@ -1045,6 +1048,9 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
return nil
}
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlockAccount()
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
@@ -1137,20 +1143,12 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
}
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount()
newUser := types.NewRegularUser(userAuth.UserId)
newUser.AccountID = domainAccountID
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, domainAccountID)
if err != nil {
return "", err
}
if settings != nil && settings.Extra != nil && settings.Extra.UserApprovalRequired {
newUser.Blocked = true
newUser.PendingApproval = true
}
err = am.Store.SaveUser(ctx, newUser)
err := am.Store.SaveUser(ctx, newUser)
if err != nil {
return "", err
}
@@ -1160,11 +1158,7 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
return "", err
}
if newUser.PendingApproval {
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, map[string]any{"pending_approval": true})
} else {
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
}
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
return domainAccountID, nil
}
@@ -1363,6 +1357,13 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil
}
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAuth.AccountId)
defer func() {
if unlockAccount != nil {
unlockAccount()
}
}()
var addNewGroups []string
var removeOldGroups []string
var hasChanges bool
@@ -1425,6 +1426,8 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return fmt.Errorf("error incrementing network serial: %w", err)
}
}
unlockAccount()
unlockAccount = nil
return nil
})
@@ -1633,16 +1636,17 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
}
func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool {
return am.loginFilter.allowLogin(wgPubKey, metahash)
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("SyncAndMarkPeer: took %v", time.Since(start))
}()
accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer accountUnlock()
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer peerUnlock()
peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
@@ -1653,18 +1657,22 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
metahash := metaHash(meta, realIP.String())
am.loginFilter.addLogin(peerPubKey, metahash)
return peer, netMap, postureChecks, nil
}
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer accountUnlock()
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer peerUnlock()
err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
}
return nil
}
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
@@ -1673,6 +1681,12 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
return err
}
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer unlock()
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer unlockPeer()
_, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
if err != nil {
return mapError(ctx, err)
@@ -1717,9 +1731,7 @@ func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, account
log.WithContext(ctx).Errorf("failed to get invalidated peer %s for account %s: %v", peerID, accountID, err)
continue
}
if peer.UserID != "" {
peers = append(peers, peer)
}
peers = append(peers, peer)
}
if len(peers) > 0 {
err := am.expireAndUpdatePeers(ctx, accountID, peers)
@@ -1815,9 +1827,6 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
RoutingPeerDNSResolutionEnabled: true,
Extra: &types.ExtraSettings{
UserApprovalRequired: true,
},
},
Onboarding: types.AccountOnboarding{
OnboardingFlowPending: true,
@@ -1924,9 +1933,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
RoutingPeerDNSResolutionEnabled: true,
Extra: &types.ExtraSettings{
UserApprovalRequired: true,
},
},
}
@@ -2112,6 +2118,9 @@ func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, pee
}
func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
if err != nil {
return fmt.Errorf("validate user permissions: %w", err)

View File

@@ -32,8 +32,6 @@ type Manager interface {
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error)
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error)
@@ -79,7 +77,7 @@ type Manager interface {
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool, skipAutoApply bool) (*route.Route, error)
CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error
DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error
ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error)
@@ -128,5 +126,4 @@ type Manager interface {
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
AllowSync(string, uint64) bool
}

View File

@@ -15,7 +15,6 @@ import (
"time"
"github.com/golang/mock/gomock"
"github.com/prometheus/client_golang/prometheus/push"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -26,7 +25,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -3048,14 +3046,19 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "sync", "syncAndMark")
}
if msPerOp > maxExpected {
b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > (maxExpected * 1.1) {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
@@ -3118,14 +3121,19 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "existingPeer")
}
if msPerOp > maxExpected {
b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > (maxExpected * 1.1) {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
@@ -3188,44 +3196,24 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer")
}
if msPerOp > maxExpected {
b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > (maxExpected * 1.1) {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
}
func TestMain(m *testing.M) {
exitCode := m.Run()
if exitCode == 0 && os.Getenv("CI") == "true" {
runID := os.Getenv("GITHUB_RUN_ID")
storeEngine := os.Getenv("NETBIRD_STORE_ENGINE")
err := push.New("http://localhost:9091", "account_manager_benchmark").
Collector(testing_tools.BenchmarkDuration).
Grouping("ci_run", runID).
Grouping("store_engine", storeEngine).
Push()
if err != nil {
log.Printf("Failed to push metrics: %v", err)
} else {
time.Sleep(1 * time.Minute)
_ = push.New("http://localhost:9091", "account_manager_benchmark").
Grouping("ci_run", runID).
Grouping("store_engine", storeEngine).
Delete()
}
}
os.Exit(exitCode)
}
func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
manager, err := createManager(t)
if err != nil {
@@ -3606,93 +3594,3 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
require.Error(t, err, "should fail with invalid peer ID")
})
}
func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
// Create a domain-based account with user approval enabled
existingAccountID := "existing-account"
account := newAccountWithId(context.Background(), existingAccountID, "owner-user", "example.com", false)
account.Settings.Extra = &types.ExtraSettings{
UserApprovalRequired: true,
}
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Set the account as domain primary account
account.IsDomainPrimaryAccount = true
account.DomainCategory = types.PrivateCategory
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Test adding new user to existing account with approval required
newUserID := "new-user-id"
userAuth := nbcontext.UserAuth{
UserId: newUserID,
Domain: "example.com",
DomainCategory: types.PrivateCategory,
}
acc, err := manager.Store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
require.True(t, acc.IsDomainPrimaryAccount, "Account should be primary for the domain")
require.Equal(t, "example.com", acc.Domain, "Account domain should match")
returnedAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), userAuth)
require.NoError(t, err)
require.Equal(t, existingAccountID, returnedAccountID)
// Verify user was created with pending approval
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, newUserID)
require.NoError(t, err)
assert.True(t, user.Blocked, "User should be blocked when approval is required")
assert.True(t, user.PendingApproval, "User should be pending approval")
assert.Equal(t, existingAccountID, user.AccountID)
}
func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
// Create a domain-based account without user approval
ownerUserAuth := nbcontext.UserAuth{
UserId: "owner-user",
Domain: "example.com",
DomainCategory: types.PrivateCategory,
}
existingAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), ownerUserAuth)
require.NoError(t, err)
// Modify the account to disable user approval
account, err := manager.Store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
account.Settings.Extra = &types.ExtraSettings{
UserApprovalRequired: false,
}
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Test adding new user to existing account without approval required
newUserID := "new-user-id"
userAuth := nbcontext.UserAuth{
UserId: newUserID,
Domain: "example.com",
DomainCategory: types.PrivateCategory,
}
returnedAccountID, err := manager.getAccountIDWithAuthorizationClaims(context.Background(), userAuth)
require.NoError(t, err)
require.Equal(t, existingAccountID, returnedAccountID)
// Verify user was created without pending approval
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, newUserID)
require.NoError(t, err)
assert.False(t, user.Blocked, "User should not be blocked when approval is not required")
assert.False(t, user.PendingApproval, "User should not be pending approval")
assert.Equal(t, existingAccountID, user.AccountID)
}

View File

@@ -177,8 +177,6 @@ const (
AccountNetworkRangeUpdated Activity = 87
PeerIPUpdated Activity = 88
UserApproved Activity = 89
UserRejected Activity = 90
JobCreatedByUser Activity = 89
@@ -290,9 +288,6 @@ var activityMap = map[Activity]Code{
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
JobCreatedByUser: {"Create Job for peer", "peer.job.create"},
UserApproved: {"User approved", "user.approve"},
UserRejected: {"User rejected", "user.reject"},
}
// StringCode returns a string code of the activity

View File

@@ -5,7 +5,7 @@ import (
"net/url"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
nbcontext "github.com/netbirdio/netbird/management/server/context"

View File

@@ -17,7 +17,7 @@ import (
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
)
@@ -63,10 +63,12 @@ type Validator struct {
}
var (
errKeyNotFound = errors.New("unable to find appropriate key")
errTokenEmpty = errors.New("required authorization token not found")
errTokenInvalid = errors.New("token is invalid")
errTokenParsing = errors.New("token could not be parsed")
errKeyNotFound = errors.New("unable to find appropriate key")
errInvalidAudience = errors.New("invalid audience")
errInvalidIssuer = errors.New("invalid issuer")
errTokenEmpty = errors.New("required authorization token not found")
errTokenInvalid = errors.New("token is invalid")
errTokenParsing = errors.New("token could not be parsed")
)
func NewValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) *Validator {
@@ -86,6 +88,24 @@ func NewValidator(issuer string, audienceList []string, keysLocation string, idp
func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc {
return func(token *jwt.Token) (interface{}, error) {
// Verify 'aud' claim
var checkAud bool
for _, audience := range v.audienceList {
checkAud = token.Claims.(jwt.MapClaims).VerifyAudience(audience, false)
if checkAud {
break
}
}
if !checkAud {
return token, errInvalidAudience
}
// Verify 'issuer' claim
checkIss := token.Claims.(jwt.MapClaims).VerifyIssuer(v.issuer, false)
if !checkIss {
return token, errInvalidIssuer
}
// If keys are rotated, verify the keys prior to token validation
if v.idpSignkeyRefreshEnabled {
// If the keys are invalid, retrieve new ones
@@ -124,7 +144,7 @@ func (v *Validator) getKeyFunc(ctx context.Context) jwt.Keyfunc {
}
// ValidateAndParse validates the token and returns the parsed token
func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
func (m *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
// If the token is empty...
if token == "" {
// If we get here, the required token is missing
@@ -133,13 +153,7 @@ func (v *Validator) ValidateAndParse(ctx context.Context, token string) (*jwt.To
}
// Now parse the token
parsedToken, err := jwt.Parse(
token,
v.getKeyFunc(ctx),
jwt.WithAudience(v.audienceList...),
jwt.WithIssuer(v.issuer),
jwt.WithIssuedAt(),
)
parsedToken, err := jwt.Parse(token, m.getKeyFunc(ctx))
// Check if there was an error in parsing...
if err != nil {

View File

@@ -7,7 +7,7 @@ import (
"fmt"
"hash/crc32"
"github.com/golang-jwt/jwt/v5"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/base62"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"

View File

@@ -3,7 +3,7 @@ package auth
import (
"context"
"github.com/golang-jwt/jwt/v5"
"github.com/golang-jwt/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"

View File

@@ -12,7 +12,7 @@ import (
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

View File

@@ -20,9 +20,29 @@ import (
// DNSConfigCache is a thread-safe cache for DNS configuration components
type DNSConfigCache struct {
CustomZones sync.Map
NameServerGroups sync.Map
}
// GetCustomZone retrieves a cached custom zone
func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) {
if c == nil {
return nil, false
}
if value, ok := c.CustomZones.Load(key); ok {
return value.(*proto.CustomZone), true
}
return nil, false
}
// SetCustomZone stores a custom zone in the cache
func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) {
if c == nil {
return
}
c.CustomZones.Store(key, value)
}
// GetNameServerGroup retrieves a cached name server group
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
if c == nil {
@@ -93,11 +113,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...)
if err = transaction.SaveDNSSettings(ctx, accountID, dnsSettingsToSave); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return transaction.SaveDNSSettings(ctx, accountID, dnsSettingsToSave)
})
if err != nil {
return err
@@ -192,8 +212,14 @@ func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSC
}
for _, zone := range update.CustomZones {
protoZone := convertToProtoCustomZone(zone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
cacheKey := zone.Domain
if cachedZone, exists := cache.GetCustomZone(cacheKey); exists {
protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone)
} else {
protoZone := convertToProtoCustomZone(zone)
cache.SetCustomZone(cacheKey, protoZone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
}
for _, nsGroup := range update.NameServerGroups {

View File

@@ -474,6 +474,15 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
t.Errorf("Results should be different for different inputs")
}
// Verify that the cache contains elements from both configs
if _, exists := cache.GetCustomZone("example.com"); !exists {
t.Errorf("Cache should contain custom zone for example.com")
}
if _, exists := cache.GetCustomZone("example.org"); !exists {
t.Errorf("Cache should contain custom zone for example.org")
}
if _, exists := cache.GetNameServerGroup("group1"); !exists {
t.Errorf("Cache should contain name server group 'group1'")
}

View File

@@ -67,6 +67,9 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
// CreateGroup object of the peers
func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -93,6 +96,10 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
return err
}
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
return status.Errorf(status.Internal, "failed to create group: %v", err)
}
@@ -102,8 +109,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err)
}
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return nil
})
if err != nil {
return err
@@ -122,6 +128,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
// UpdateGroup object of the peers
func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -167,11 +176,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err
}
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return transaction.UpdateGroup(ctx, newGroup)
})
if err != nil {
return err
@@ -202,45 +211,35 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
var eventsToStore []func()
var groupsToSave []*types.Group
var updateAccountPeers bool
var globalErr error
groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
newGroup.AccountID = accountID
if err = transaction.CreateGroup(ctx, newGroup); err != nil {
return err
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return err
}
groupsToSave = append(groupsToSave, newGroup)
groupIDs = append(groupIDs, newGroup.ID)
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
return nil
})
if err != nil {
log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
if len(groupIDs) == 1 {
return err
}
globalErr = errors.Join(globalErr, err)
// continue updating other groups
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.CreateGroups(ctx, accountID, groupsToSave)
})
if err != nil {
return err
}
@@ -253,7 +252,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
am.UpdateAccountPeers(ctx, accountID)
}
return globalErr
return nil
}
// UpdateGroups updates groups in the account.
@@ -270,45 +269,35 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
var eventsToStore []func()
var groupsToSave []*types.Group
var updateAccountPeers bool
var globalErr error
groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
}
newGroup.AccountID = accountID
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
return err
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return err
}
groupsToSave = append(groupsToSave, newGroup)
groupIDs = append(groupIDs, newGroup.ID)
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
groupIDs = append(groupIDs, newGroup.ID)
return nil
})
if err != nil {
log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
if len(groups) == 1 {
return err
}
globalErr = errors.Join(globalErr, err)
// continue updating other groups
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.UpdateGroups(ctx, accountID, groupsToSave)
})
if err != nil {
return err
}
@@ -321,7 +310,7 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
am.UpdateAccountPeers(ctx, accountID)
}
return globalErr
return nil
}
// prepareGroupEvents prepares a list of event functions to be stored.
@@ -393,6 +382,8 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
// DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
}
@@ -432,11 +423,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
deletedGroups = append(deletedGroups, group)
}
if err = transaction.DeleteGroups(ctx, accountID, groupIDsToDelete); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return transaction.DeleteGroups(ctx, accountID, groupIDsToDelete)
})
if err != nil {
return err
@@ -451,6 +442,9 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var updateAccountPeers bool
var err error
@@ -460,11 +454,11 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err
}
if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return transaction.AddPeerToGroup(ctx, accountID, peerID, groupID)
})
if err != nil {
return err
@@ -479,6 +473,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupAddResource appends resource to the group
func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var group *types.Group
var updateAccountPeers bool
var err error
@@ -498,11 +495,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
return err
}
if err = transaction.UpdateGroup(ctx, group); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return transaction.UpdateGroup(ctx, group)
})
if err != nil {
return err
@@ -517,6 +514,9 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var updateAccountPeers bool
var err error
@@ -526,11 +526,11 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err
}
if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return transaction.RemovePeerFromGroup(ctx, peerID, groupID)
})
if err != nil {
return err
@@ -545,6 +545,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
// GroupDeleteResource removes resource from the group
func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var group *types.Group
var updateAccountPeers bool
var err error
@@ -564,11 +567,11 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
return err
}
if err = transaction.UpdateGroup(ctx, group); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return transaction.UpdateGroup(ctx, group)
})
if err != nil {
return err
@@ -604,6 +607,13 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
return nil
}

View File

@@ -648,7 +648,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
_, err := manager.CreateRoute(
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, newRoute.SkipAutoApply,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
)
require.NoError(t, err)

View File

@@ -7,7 +7,6 @@ import (
"io"
"net"
"net/netip"
"os"
"strings"
"sync"
"time"
@@ -41,30 +40,21 @@ import (
internalStatus "github.com/netbirdio/netbird/shared/management/status"
)
const (
envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS"
envBlockPeers = "NB_BLOCK_SAME_PEERS"
)
// GRPCServer an instance of a Management gRPC API server
type GRPCServer struct {
accountManager account.Manager
settingsManager settings.Manager
wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager
jobManager *JobManager
config *nbconfig.Config
secretsManager SecretsManager
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
peerLocks sync.Map
authManager auth.Manager
logBlockedPeers bool
blockPeersWithSameConfig bool
integratedPeerValidator integrated_validator.IntegratedValidator
peersUpdateManager *PeersUpdateManager
jobManager *JobManager
config *nbconfig.Config
secretsManager SecretsManager
appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager
peerLocks sync.Map
authManager auth.Manager
integratedPeerValidator integrated_validator.IntegratedValidator
}
// NewServer creates a new Management server
@@ -95,24 +85,19 @@ func NewServer(
}
}
logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true"
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
return &GRPCServer{
wgKey: key,
// peerKey -> event channel
peersUpdateManager: peersUpdateManager,
jobManager: jobManager,
accountManager: accountManager,
settingsManager: settingsManager,
config: config,
secretsManager: secretsManager,
authManager: authManager,
appMetrics: appMetrics,
ephemeralManager: ephemeralManager,
logBlockedPeers: logBlockedPeers,
blockPeersWithSameConfig: blockPeersWithSameConfig,
integratedPeerValidator: integratedPeerValidator,
peersUpdateManager: peersUpdateManager,
jobManager: jobManager,
accountManager: accountManager,
settingsManager: settingsManager,
config: config,
secretsManager: secretsManager,
authManager: authManager,
appMetrics: appMetrics,
ephemeralManager: ephemeralManager,
integratedPeerValidator: integratedPeerValidator,
}, nil
}
@@ -192,6 +177,9 @@ func (s *GRPCServer) Job(srv proto.ManagementService_JobServer) error {
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
reqStart := time.Now()
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequest()
}
ctx := srv.Context()
@@ -200,27 +188,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
if err != nil {
return err
}
realIP := getRealIP(ctx)
sRealIP := realIP.String()
peerMeta := extractPeerMeta(ctx, syncReq.GetMeta())
metahashed := metaHash(peerMeta, sRealIP)
if !s.accountManager.AllowSync(peerKey.String(), metahashed) {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
}
if s.logBlockedPeers {
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
}
if s.blockPeersWithSameConfig {
return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn)
}
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequest()
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
@@ -244,12 +211,14 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
realIP := getRealIP(ctx)
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
if syncReq.GetMeta() == nil {
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
return mapError(ctx, err)
@@ -267,7 +236,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID)
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
}
unlock()
@@ -366,7 +335,6 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
}
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
}
@@ -505,9 +473,6 @@ func mapError(ctx context.Context, err error) error {
default:
}
}
if errors.Is(err, internalStatus.ErrPeerAlreadyLoggedIn) {
return status.Error(codes.PermissionDenied, internalStatus.ErrPeerAlreadyLoggedIn.Error())
}
log.WithContext(ctx).Errorf("got an unhandled error: %s", err)
return status.Errorf(codes.Internal, "failed handling request")
}
@@ -599,9 +564,16 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa
// In case of the successful registration login is also successful
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
reqStart := time.Now()
defer func() {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart))
}
}()
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequest()
}
realIP := getRealIP(ctx)
sRealIP := realIP.String()
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, sRealIP)
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
loginReq := &proto.LoginRequest{}
peerKey, err := s.parseRequest(ctx, req, loginReq)
@@ -609,24 +581,6 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
return nil, err
}
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
metahashed := metaHash(peerMeta, sRealIP)
if !s.accountManager.AllowSync(peerKey.String(), metahashed) {
if s.logBlockedPeers {
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestBlocked()
}
if s.blockPeersWithSameConfig {
return nil, internalStatus.ErrPeerAlreadyLoggedIn
}
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequest()
}
//nolint
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
@@ -637,12 +591,6 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
defer func() {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
}
}()
if loginReq.GetMeta() == nil {
msg := status.Errorf(codes.FailedPrecondition,
"peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP)
@@ -663,7 +611,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey),
Meta: peerMeta,
Meta: extractPeerMeta(ctx, loginReq.GetMeta()),
UserID: userID,
SetupKey: loginReq.GetSetupKey(),
ConnectionIP: realIP,
@@ -1129,6 +1077,8 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
return nil, mapError(ctx, err)
}
s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start))
return &proto.Empty{}, nil

View File

@@ -11,11 +11,11 @@ import (
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
)
const (
@@ -198,7 +198,6 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
if req.Settings.Extra != nil {
settings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
@@ -328,7 +327,6 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
if settings.Extra != nil {
apiSettings.Extra = &api.AccountExtraSettings{
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
UserApprovalRequired: settings.Extra.UserApprovalRequired,
NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled,
NetworkTrafficLogsGroups: settings.Extra.FlowGroups,
NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled,

View File

@@ -15,11 +15,11 @@ import (
"github.com/stretchr/testify/assert"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
)
func initAccountsTestData(t *testing.T, account *types.Account) *handler {

View File

@@ -488,33 +488,33 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
}
return &api.PeerBatch{
CreatedAt: peer.CreatedAt,
Id: peer.ID,
Name: peer.Name,
Ip: peer.IP.String(),
ConnectionIp: peer.Location.ConnectionIP.String(),
Connected: peer.Status.Connected,
LastSeen: peer.Status.LastSeen,
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
KernelVersion: peer.Meta.KernelVersion,
GeonameId: int(peer.Location.GeoNameID),
Version: peer.Meta.WtVersion,
Groups: groupsInfo,
SshEnabled: peer.SSHEnabled,
Hostname: peer.Meta.Hostname,
UserId: peer.UserID,
UiVersion: peer.Meta.UIVersion,
DnsLabel: fqdn(peer, dnsDomain),
ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain),
LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.GetLastLogin(),
LoginExpired: peer.Status.LoginExpired,
AccessiblePeersCount: accessiblePeersCount,
CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName,
SerialNumber: peer.Meta.SystemSerialNumber,
CreatedAt: peer.CreatedAt,
Id: peer.ID,
Name: peer.Name,
Ip: peer.IP.String(),
ConnectionIp: peer.Location.ConnectionIP.String(),
Connected: peer.Status.Connected,
LastSeen: peer.Status.LastSeen,
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
KernelVersion: peer.Meta.KernelVersion,
GeonameId: int(peer.Location.GeoNameID),
Version: peer.Meta.WtVersion,
Groups: groupsInfo,
SshEnabled: peer.SSHEnabled,
Hostname: peer.Meta.Hostname,
UserId: peer.UserID,
UiVersion: peer.Meta.UIVersion,
DnsLabel: fqdn(peer, dnsDomain),
ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain),
LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.GetLastLogin(),
LoginExpired: peer.Status.LoginExpired,
AccessiblePeersCount: accessiblePeersCount,
CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName,
SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral,
}
}

View File

@@ -8,19 +8,17 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/route"
)
const failedToConvertRoute = "failed to convert route to response: %v"
const exitNodeCIDR = "0.0.0.0/0"
// handler is the routes handler of the account
type handler struct {
accountManager account.Manager
@@ -126,16 +124,8 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
accessControlGroupIds = *req.AccessControlGroups
}
// Set default skipAutoApply value for exit nodes (0.0.0.0/0 routes)
skipAutoApply := false
if req.SkipAutoApply != nil {
skipAutoApply = *req.SkipAutoApply
} else if newPrefix.String() == exitNodeCIDR {
skipAutoApply = false
}
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute, skipAutoApply)
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -152,31 +142,23 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
}
func (h *handler) validateRoute(req api.PostApiRoutesJSONRequestBody) error {
return h.validateRouteCommon(req.Network, req.Domains, req.Peer, req.PeerGroups, req.NetworkId)
}
func (h *handler) validateRouteUpdate(req api.PutApiRoutesRouteIdJSONRequestBody) error {
return h.validateRouteCommon(req.Network, req.Domains, req.Peer, req.PeerGroups, req.NetworkId)
}
func (h *handler) validateRouteCommon(network *string, domains *[]string, peer *string, peerGroups *[]string, networkId string) error {
if network != nil && domains != nil {
if req.Network != nil && req.Domains != nil {
return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided")
}
if network == nil && domains == nil {
if req.Network == nil && req.Domains == nil {
return status.Errorf(status.InvalidArgument, "either 'network' or 'domains' should be provided")
}
if peer == nil && peerGroups == nil {
if req.Peer == nil && req.PeerGroups == nil {
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided")
}
if peer != nil && peerGroups != nil {
if req.Peer != nil && req.PeerGroups != nil {
return status.Errorf(status.InvalidArgument, "only one of 'peer' or 'peer_groups' should be provided")
}
if utf8.RuneCountInString(networkId) > route.MaxNetIDChar || networkId == "" {
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d characters",
route.MaxNetIDChar)
}
@@ -213,7 +195,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.validateRouteUpdate(req); err != nil {
if err := h.validateRoute(req); err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -223,24 +205,15 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
peerID = *req.Peer
}
// Set default skipAutoApply value for exit nodes (0.0.0.0/0 routes)
skipAutoApply := false
if req.SkipAutoApply != nil {
skipAutoApply = *req.SkipAutoApply
} else if req.Network != nil && *req.Network == exitNodeCIDR {
skipAutoApply = false
}
newRoute := &route.Route{
ID: route.ID(routeID),
NetID: route.NetID(req.NetworkId),
Masquerade: req.Masquerade,
Metric: req.Metric,
Description: req.Description,
Enabled: req.Enabled,
Groups: req.Groups,
KeepRoute: req.KeepRoute,
SkipAutoApply: skipAutoApply,
ID: route.ID(routeID),
NetID: route.NetID(req.NetworkId),
Masquerade: req.Masquerade,
Metric: req.Metric,
Description: req.Description,
Enabled: req.Enabled,
Groups: req.Groups,
KeepRoute: req.KeepRoute,
}
if req.Domains != nil {
@@ -348,19 +321,18 @@ func toRouteResponse(serverRoute *route.Route) (*api.Route, error) {
}
network := serverRoute.Network.String()
route := &api.Route{
Id: string(serverRoute.ID),
Description: serverRoute.Description,
NetworkId: string(serverRoute.NetID),
Enabled: serverRoute.Enabled,
Peer: &serverRoute.Peer,
Network: &network,
Domains: &domains,
NetworkType: serverRoute.NetworkType.String(),
Masquerade: serverRoute.Masquerade,
Metric: serverRoute.Metric,
Groups: serverRoute.Groups,
KeepRoute: serverRoute.KeepRoute,
SkipAutoApply: &serverRoute.SkipAutoApply,
Id: string(serverRoute.ID),
Description: serverRoute.Description,
NetworkId: string(serverRoute.NetID),
Enabled: serverRoute.Enabled,
Peer: &serverRoute.Peer,
Network: &network,
Domains: &domains,
NetworkType: serverRoute.NetworkType.String(),
Masquerade: serverRoute.Masquerade,
Metric: serverRoute.Metric,
Groups: serverRoute.Groups,
KeepRoute: serverRoute.KeepRoute,
}
if len(serverRoute.PeerGroups) > 0 {

View File

@@ -15,13 +15,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/domain"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
@@ -62,22 +62,21 @@ func initRoutesTestData() *handler {
return &handler{
accountManager: &mock_server.MockAccountManager{
GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) {
switch routeID {
case existingRouteID:
if routeID == existingRouteID {
return baseExistingRoute, nil
case existingRouteID2:
}
if routeID == existingRouteID2 {
route := baseExistingRoute.Copy()
route.PeerGroups = []string{existingGroupID}
return route, nil
case existingRouteID3:
} else if routeID == existingRouteID3 {
route := baseExistingRoute.Copy()
route.Domains = domain.List{existingDomain}
return route, nil
default:
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
}
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
},
CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool, skipAutoApply bool) (*route.Route, error) {
CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
if peerID == notFoundPeerID {
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
@@ -104,7 +103,6 @@ func initRoutesTestData() *handler {
Groups: groups,
KeepRoute: keepRoute,
AccessControlGroups: accessControlGroups,
SkipAutoApply: skipAutoApply,
}, nil
},
SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error {
@@ -192,20 +190,19 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPost,
requestPath: "/api/routes",
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","network_id":"awesomeNet","Peer":"%s","groups":["%s"],"skip_auto_apply":false}`, existingPeerID, existingGroupID))),
[]byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID))),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRoute: &api.Route{
Id: existingRouteID,
Description: "Post",
NetworkId: "awesomeNet",
Network: util.ToPtr("192.168.0.0/16"),
Peer: &existingPeerID,
NetworkType: route.IPv4NetworkString,
Masquerade: false,
Enabled: false,
Groups: []string{existingGroupID},
SkipAutoApply: util.ToPtr(false),
Id: existingRouteID,
Description: "Post",
NetworkId: "awesomeNet",
Network: util.ToPtr("192.168.0.0/16"),
Peer: &existingPeerID,
NetworkType: route.IPv4NetworkString,
Masquerade: false,
Enabled: false,
Groups: []string{existingGroupID},
},
},
{
@@ -213,22 +210,21 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPost,
requestPath: "/api/routes",
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf(`{"description":"Post","domains":["example.com"],"network_id":"domainNet","peer":"%s","groups":["%s"],"keep_route":true,"skip_auto_apply":false}`, existingPeerID, existingGroupID))),
[]byte(fmt.Sprintf(`{"description":"Post","domains":["example.com"],"network_id":"domainNet","peer":"%s","groups":["%s"],"keep_route":true}`, existingPeerID, existingGroupID))),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRoute: &api.Route{
Id: existingRouteID,
Description: "Post",
NetworkId: "domainNet",
Network: util.ToPtr("invalid Prefix"),
KeepRoute: true,
Domains: &[]string{existingDomain},
Peer: &existingPeerID,
NetworkType: route.DomainNetworkString,
Masquerade: false,
Enabled: false,
Groups: []string{existingGroupID},
SkipAutoApply: util.ToPtr(false),
Id: existingRouteID,
Description: "Post",
NetworkId: "domainNet",
Network: util.ToPtr("invalid Prefix"),
KeepRoute: true,
Domains: &[]string{existingDomain},
Peer: &existingPeerID,
NetworkType: route.DomainNetworkString,
Masquerade: false,
Enabled: false,
Groups: []string{existingGroupID},
},
},
{
@@ -236,7 +232,7 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPost,
requestPath: "/api/routes",
requestBody: bytes.NewBuffer(
[]byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"],\"skip_auto_apply\":false}", existingPeerID, existingGroupID, existingGroupID))),
[]byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"]}", existingPeerID, existingGroupID, existingGroupID))),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRoute: &api.Route{
@@ -250,7 +246,6 @@ func TestRoutesHandlers(t *testing.T) {
Enabled: false,
Groups: []string{existingGroupID},
AccessControlGroups: &[]string{existingGroupID},
SkipAutoApply: util.ToPtr(false),
},
},
{
@@ -341,63 +336,60 @@ func TestRoutesHandlers(t *testing.T) {
name: "Network PUT OK",
requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"is_selected\":true}", existingPeerID, existingGroupID)),
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"]}", existingPeerID, existingGroupID)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRoute: &api.Route{
Id: existingRouteID,
Description: "Post",
NetworkId: "awesomeNet",
Network: util.ToPtr("192.168.0.0/16"),
Peer: &existingPeerID,
NetworkType: route.IPv4NetworkString,
Masquerade: false,
Enabled: false,
Groups: []string{existingGroupID},
SkipAutoApply: util.ToPtr(false),
Id: existingRouteID,
Description: "Post",
NetworkId: "awesomeNet",
Network: util.ToPtr("192.168.0.0/16"),
Peer: &existingPeerID,
NetworkType: route.IPv4NetworkString,
Masquerade: false,
Enabled: false,
Groups: []string{existingGroupID},
},
},
{
name: "Domains PUT OK",
requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"],"keep_route":true,"skip_auto_apply":false}`, existingPeerID, existingGroupID)),
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"],"keep_route":true}`, existingPeerID, existingGroupID)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRoute: &api.Route{
Id: existingRouteID,
Description: "Post",
NetworkId: "awesomeNet",
Network: util.ToPtr("invalid Prefix"),
Domains: &[]string{existingDomain},
Peer: &existingPeerID,
NetworkType: route.DomainNetworkString,
Masquerade: false,
Enabled: false,
Groups: []string{existingGroupID},
KeepRoute: true,
SkipAutoApply: util.ToPtr(false),
Id: existingRouteID,
Description: "Post",
NetworkId: "awesomeNet",
Network: util.ToPtr("invalid Prefix"),
Domains: &[]string{existingDomain},
Peer: &existingPeerID,
NetworkType: route.DomainNetworkString,
Masquerade: false,
Enabled: false,
Groups: []string{existingGroupID},
KeepRoute: true,
},
},
{
name: "PUT OK when peer_groups provided",
requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"peer_groups\":[\"%s\"],\"groups\":[\"%s\"],\"skip_auto_apply\":false}", existingGroupID, existingGroupID)),
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"peer_groups\":[\"%s\"],\"groups\":[\"%s\"]}", existingGroupID, existingGroupID)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRoute: &api.Route{
Id: existingRouteID,
Description: "Post",
NetworkId: "awesomeNet",
Network: util.ToPtr("192.168.0.0/16"),
Peer: &emptyString,
PeerGroups: &[]string{existingGroupID},
NetworkType: route.IPv4NetworkString,
Masquerade: false,
Enabled: false,
Groups: []string{existingGroupID},
SkipAutoApply: util.ToPtr(false),
Id: existingRouteID,
Description: "Post",
NetworkId: "awesomeNet",
Network: util.ToPtr("192.168.0.0/16"),
Peer: &emptyString,
PeerGroups: &[]string{existingGroupID},
NetworkType: route.IPv4NetworkString,
Masquerade: false,
Enabled: false,
Groups: []string{existingGroupID},
},
},
{

View File

@@ -9,11 +9,11 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
nbcontext "github.com/netbirdio/netbird/management/server/context"
)
@@ -31,8 +31,6 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS")
addUsersTokensEndpoint(accountManager, router)
}
@@ -325,76 +323,17 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
}
isCurrent := user.ID == currenUserID
return &api.User{
Id: user.ID,
Name: user.Name,
Email: user.Email,
Role: user.Role,
AutoGroups: autoGroups,
Status: userStatus,
IsCurrent: &isCurrent,
IsServiceUser: &user.IsServiceUser,
IsBlocked: user.IsBlocked,
LastLogin: &user.LastLogin,
Issued: &user.Issued,
PendingApproval: user.PendingApproval,
Id: user.ID,
Name: user.Name,
Email: user.Email,
Role: user.Role,
AutoGroups: autoGroups,
Status: userStatus,
IsCurrent: &isCurrent,
IsServiceUser: &user.IsServiceUser,
IsBlocked: user.IsBlocked,
LastLogin: &user.LastLogin,
Issued: &user.Issued,
}
}
// approveUser is a POST request to approve a user that is pending approval
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
userResponse := toUserResponse(user, userAuth.UserId)
util.WriteJSONObject(r.Context(), w, userResponse)
}
// rejectUser is a DELETE request to reject a user that is pending approval
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -16,13 +16,13 @@ import (
"github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
@@ -725,133 +725,3 @@ func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[stri
}
return modules
}
func TestApproveUserEndpoint(t *testing.T) {
adminUser := &types.User{
Id: "admin-user",
Role: types.UserRoleAdmin,
AccountID: existingAccountID,
AutoGroups: []string{},
}
pendingUser := &types.User{
Id: "pending-user",
Role: types.UserRoleUser,
AccountID: existingAccountID,
Blocked: true,
PendingApproval: true,
AutoGroups: []string{},
}
tt := []struct {
name string
expectedStatus int
expectedBody bool
requestingUser *types.User
}{
{
name: "approve user as admin should return 200",
expectedStatus: 200,
expectedBody: true,
requestingUser: adminUser,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{}
am.ApproveUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) {
approvedUserInfo := &types.UserInfo{
ID: pendingUser.Id,
Email: "pending@example.com",
Name: "Pending User",
Role: string(pendingUser.Role),
AutoGroups: []string{},
IsServiceUser: false,
IsBlocked: false,
PendingApproval: false,
LastLogin: time.Now(),
Issued: types.UserIssuedAPI,
}
return approvedUserInfo, nil
}
handler := newHandler(am)
router := mux.NewRouter()
router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST")
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
require.NoError(t, err)
userAuth := nbcontext.UserAuth{
AccountId: existingAccountID,
UserId: tc.requestingUser.Id,
}
ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth)
req = req.WithContext(ctx)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
if tc.expectedBody {
var response api.User
err = json.Unmarshal(rr.Body.Bytes(), &response)
require.NoError(t, err)
assert.Equal(t, "pending-user", response.Id)
assert.False(t, response.IsBlocked)
assert.False(t, response.PendingApproval)
}
})
}
}
func TestRejectUserEndpoint(t *testing.T) {
adminUser := &types.User{
Id: "admin-user",
Role: types.UserRoleAdmin,
AccountID: existingAccountID,
AutoGroups: []string{},
}
tt := []struct {
name string
expectedStatus int
requestingUser *types.User
}{
{
name: "reject user as admin should return 200",
expectedStatus: 200,
requestingUser: adminUser,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
am := &mock_server.MockAccountManager{}
am.RejectUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error {
return nil
}
handler := newHandler(am)
router := mux.NewRouter()
router.HandleFunc("/users/{userId}/reject", handler.rejectUser).Methods("DELETE")
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
require.NoError(t, err)
userAuth := nbcontext.UserAuth{
AccountId: existingAccountID,
UserId: tc.requestingUser.Id,
}
ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth)
req = req.WithContext(ctx)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
assert.Equal(t, tc.expectedStatus, rr.Code)
})
}
}

View File

@@ -13,9 +13,9 @@ import (
"github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/types"
)
type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)

View File

@@ -8,15 +8,16 @@ import (
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/auth"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
)
const (

View File

@@ -17,9 +17,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
)
const modulePeers = "peers"
@@ -48,7 +47,7 @@ func BenchmarkUpdatePeer(b *testing.B) {
for name, bc := range benchCasesPeers {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
@@ -66,7 +65,7 @@ func BenchmarkUpdatePeer(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate)
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate)
})
}
}
@@ -83,7 +82,7 @@ func BenchmarkGetOnePeer(b *testing.B) {
for name, bc := range benchCasesPeers {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
@@ -93,7 +92,7 @@ func BenchmarkGetOnePeer(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne)
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne)
})
}
}
@@ -110,7 +109,7 @@ func BenchmarkGetAllPeers(b *testing.B) {
for name, bc := range benchCasesPeers {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
@@ -120,7 +119,7 @@ func BenchmarkGetAllPeers(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll)
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll)
})
}
}
@@ -137,7 +136,7 @@ func BenchmarkDeletePeer(b *testing.B) {
for name, bc := range benchCasesPeers {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), 1000, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
@@ -147,7 +146,7 @@ func BenchmarkDeletePeer(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete)
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete)
})
}
}

View File

@@ -17,9 +17,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
)
// Map to store peers, groups, users, and setupKeys by name
@@ -48,7 +47,7 @@ func BenchmarkCreateSetupKey(b *testing.B) {
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
@@ -70,7 +69,7 @@ func BenchmarkCreateSetupKey(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate)
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate)
})
}
}
@@ -87,7 +86,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) {
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
@@ -110,7 +109,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate)
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate)
})
}
}
@@ -127,7 +126,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) {
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
@@ -137,7 +136,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne)
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne)
})
}
}
@@ -154,7 +153,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) {
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
@@ -164,7 +163,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll)
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll)
})
}
}
@@ -181,7 +180,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) {
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000)
b.ResetTimer()
@@ -191,7 +190,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete)
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete)
})
}
}

View File

@@ -18,9 +18,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
)
const moduleUsers = "users"
@@ -47,7 +46,7 @@ func BenchmarkUpdateUser(b *testing.B) {
for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
recorder := httptest.NewRecorder()
@@ -72,7 +71,7 @@ func BenchmarkUpdateUser(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate)
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate)
})
}
}
@@ -85,18 +84,18 @@ func BenchmarkGetOneUser(b *testing.B) {
for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
recorder := httptest.NewRecorder()
b.ResetTimer()
start := time.Now()
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId)
for i := 0; i < b.N; i++ {
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne)
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne)
})
}
}
@@ -111,18 +110,18 @@ func BenchmarkGetAllUsers(b *testing.B) {
for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
recorder := httptest.NewRecorder()
b.ResetTimer()
start := time.Now()
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId)
for i := 0; i < b.N; i++ {
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll)
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll)
})
}
}
@@ -137,7 +136,7 @@ func BenchmarkDeleteUsers(b *testing.B) {
for name, bc := range benchCasesUsers {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys)
recorder := httptest.NewRecorder()
@@ -148,7 +147,7 @@ func BenchmarkDeleteUsers(b *testing.B) {
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete)
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete)
})
}
}

View File

@@ -15,10 +15,9 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func Test_SetupKeys_Create(t *testing.T) {
@@ -288,7 +287,7 @@ func Test_SetupKeys_Create(t *testing.T) {
for _, tc := range tt {
for _, user := range users {
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
body, err := json.Marshal(tc.requestBody)
if err != nil {
@@ -573,7 +572,7 @@ func Test_SetupKeys_Update(t *testing.T) {
for _, tc := range tt {
for _, user := range users {
t.Run(tc.name, func(t *testing.T) {
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
body, err := json.Marshal(tc.requestBody)
if err != nil {
@@ -752,7 +751,7 @@ func Test_SetupKeys_Get(t *testing.T) {
for _, tc := range tt {
for _, user := range users {
t.Run(tc.name, func(t *testing.T) {
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId)
@@ -904,7 +903,7 @@ func Test_SetupKeys_GetAll(t *testing.T) {
for _, tc := range tt {
for _, user := range users {
t.Run(tc.name, func(t *testing.T) {
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, tc.requestPath, user.userId)
@@ -1088,7 +1087,7 @@ func Test_SetupKeys_Delete(t *testing.T) {
for _, tc := range tt {
for _, user := range users {
t.Run(tc.name, func(t *testing.T) {
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId)

View File

@@ -1,137 +0,0 @@
package channel
import (
"context"
"errors"
"net/http"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/netbirdio/management-integrations/integrations"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
http2 "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/users"
)
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir())
if err != nil {
t.Fatalf("Failed to create test store: %v", err)
}
t.Cleanup(cleanup)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
t.Fatalf("Failed to create metrics: %v", err)
}
peersUpdateManager := server.NewPeersUpdateManager(nil)
updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId)
done := make(chan struct{})
if validateUpdate {
go func() {
if expectedPeerUpdate != nil {
peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate)
} else {
peerShouldNotReceiveUpdate(t, updMsg)
}
close(done)
}()
}
geoMock := &geolocation.Mock{}
validatorMock := server.MockIntegratedValidator{}
proxyController := integrations.NewController(store)
userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := auth.NewManager(store, "", "", "", "", []string{}, false)
authManagerMock := &auth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
MarkPATUsedFunc: authManager.MarkPATUsed,
GetPATInfoFunc: authManager.GetPATInfo,
}
networksManagerMock := networks.NewManagerMock()
resourcesManagerMock := resources.NewManagerMock()
routersManagerMock := routers.NewManagerMock()
groupsManagerMock := groups.NewManagerMock()
peersManager := peers.NewManager(store, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
return apiHandler, am, done
}
func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) {
t.Helper()
select {
case msg := <-updateMessage:
t.Errorf("Unexpected message received: %+v", msg)
case <-time.After(500 * time.Millisecond):
return
}
}
func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) {
t.Helper()
select {
case msg := <-updateMessage:
if msg == nil {
t.Errorf("Received nil update message, expected valid message")
}
assert.Equal(t, expected, msg)
case <-time.After(500 * time.Millisecond):
t.Errorf("Timed out waiting for update message")
}
}
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
userAuth := nbcontext.UserAuth{}
switch token {
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
userAuth.UserId = token
userAuth.AccountId = "testAccountId"
userAuth.Domain = "test.com"
userAuth.DomainCategory = "private"
case "otherUserId":
userAuth.UserId = "otherUserId"
userAuth.AccountId = "otherAccountId"
userAuth.Domain = "other.com"
userAuth.DomainCategory = "private"
case "invalidToken":
return userAuth, nil, errors.New("invalid token")
}
jwtToken := jwt.New(jwt.SigningMethodHS256)
return userAuth, jwtToken, nil
}

View File

@@ -3,6 +3,7 @@ package testing_tools
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
@@ -13,12 +14,32 @@ import (
"testing"
"time"
"github.com/golang-jwt/jwt"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
)
@@ -202,11 +223,11 @@ func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedSta
return content, expectedStatus == http.StatusOK
}
func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, setupKeys int) {
func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, groups, users, setupKeys int) {
b.Helper()
ctx := context.Background()
acc, err := am.GetAccount(ctx, TestAccountId)
account, err := am.GetAccount(ctx, TestAccountId)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
@@ -222,23 +243,23 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se
Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true},
UserID: TestUserId,
}
acc.Peers[peer.ID] = peer
account.Peers[peer.ID] = peer
}
// Create users
for i := 0; i < users; i++ {
user := &types.User{
Id: fmt.Sprintf("olduser-%d", i),
AccountID: acc.Id,
AccountID: account.Id,
Role: types.UserRoleUser,
}
acc.Users[user.Id] = user
account.Users[user.Id] = user
}
for i := 0; i < setupKeys; i++ {
key := &types.SetupKey{
Id: fmt.Sprintf("oldkey-%d", i),
AccountID: acc.Id,
AccountID: account.Id,
AutoGroups: []string{"someGroupID"},
UpdatedAt: time.Now().UTC(),
ExpiresAt: util.ToPtr(time.Now().Add(ExpiresIn * time.Second)),
@@ -246,11 +267,11 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se
Type: "reusable",
UsageLimit: 0,
}
acc.SetupKeys[key.Id] = key
account.SetupKeys[key.Id] = key
}
// Create groups and policies
acc.Policies = make([]*types.Policy, 0, groups)
account.Policies = make([]*types.Policy, 0, groups)
for i := 0; i < groups; i++ {
groupID := fmt.Sprintf("group-%d", i)
group := &types.Group{
@@ -261,7 +282,7 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se
peerIndex := i*(peers/groups) + j
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
}
acc.Groups[groupID] = group
account.Groups[groupID] = group
// Create a policy for this group
policy := &types.Policy{
@@ -281,10 +302,10 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se
},
},
}
acc.Policies = append(acc.Policies, policy)
account.Policies = append(account.Policies, policy)
}
acc.PostureChecks = []*posture.Checks{
account.PostureChecks = []*posture.Checks{
{
ID: "PostureChecksAll",
Name: "All",
@@ -296,38 +317,52 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se
},
}
store := am.GetStore()
err = store.SaveAccount(context.Background(), acc)
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
b.Fatalf("Failed to save account: %v", err)
}
}
func EvaluateAPIBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) {
b.Helper()
if recorder.Code != http.StatusOK {
b.Fatalf("Benchmark %s failed: unexpected status code %d", testCase, recorder.Code)
}
EvaluateBenchmarkResults(b, testCase, duration, module, operation)
}
func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, module string, operation string) {
func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) {
b.Helper()
branch := os.Getenv("GIT_BRANCH")
if branch == "" && os.Getenv("CI") == "true" {
if branch == "" {
b.Fatalf("environment variable GIT_BRANCH is not set")
}
if recorder.Code != http.StatusOK {
b.Fatalf("Benchmark %s failed: unexpected status code %d", testCase, recorder.Code)
}
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
gauge := BenchmarkDuration.WithLabelValues(module, operation, testCase, branch)
gauge.Set(msPerOp)
b.ReportMetric(msPerOp, "ms/op")
}
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
userAuth := nbcontext.UserAuth{}
switch token {
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
userAuth.UserId = token
userAuth.AccountId = "testAccountId"
userAuth.Domain = "test.com"
userAuth.DomainCategory = "private"
case "otherUserId":
userAuth.UserId = "otherUserId"
userAuth.AccountId = "otherAccountId"
userAuth.Domain = "other.com"
userAuth.DomainCategory = "private"
case "invalidToken":
return userAuth, nil, errors.New("invalid token")
}
jwtToken := jwt.New(jwt.SigningMethodHS256)
return userAuth, jwtToken, nil
}

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"compress/gzip"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
@@ -17,6 +16,7 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
)
@@ -231,7 +231,7 @@ func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTTo
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
}
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
if err != nil {
return jwtToken, err
}

View File

@@ -11,11 +11,12 @@ import (
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/assert"
)
type mockHTTPClient struct {

View File

@@ -2,7 +2,6 @@ package idp
import (
"context"
"encoding/base64"
"fmt"
"io"
"net/http"
@@ -12,6 +11,7 @@ import (
"sync"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"goauthentik.io/api/v3"
@@ -166,7 +166,7 @@ func (ac *AuthentikCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
}
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
if err != nil {
return jwtToken, err
}

View File

@@ -2,7 +2,6 @@ package idp
import (
"context"
"encoding/base64"
"fmt"
"io"
"net/http"
@@ -11,6 +10,7 @@ import (
"sync"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -168,7 +168,7 @@ func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTT
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
}
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
if err != nil {
return jwtToken, err
}

View File

@@ -2,7 +2,6 @@ package idp
import (
"context"
"encoding/base64"
"fmt"
"io"
"net/http"
@@ -12,6 +11,7 @@ import (
"sync"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -158,7 +158,7 @@ func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (J
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
}
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
if err != nil {
return jwtToken, err
}

View File

@@ -2,7 +2,6 @@ package idp
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -13,6 +12,7 @@ import (
"sync"
"time"
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -253,7 +253,7 @@ func (zc *ZitadelCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JW
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
}
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1])
if err != nil {
return jwtToken, err
}

View File

@@ -46,6 +46,9 @@ func (am *DefaultAccountManager) UpdateIntegratedValidator(ctx context.Context,
groups = []string{}
}
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {

View File

@@ -3,14 +3,12 @@ package port_forwarding
import (
"context"
"github.com/netbirdio/netbird/management/server/peer"
nbtypes "github.com/netbirdio/netbird/management/server/types"
)
type Controller interface {
SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string, accountPeers map[string]*peer.Peer)
GetProxyNetworkMaps(ctx context.Context, accountID, peerID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error)
GetProxyNetworkMapsAll(ctx context.Context, accountID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error)
SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string)
GetProxyNetworkMaps(ctx context.Context, accountID string) (map[string]*nbtypes.NetworkMap, error)
IsPeerInIngressPorts(ctx context.Context, accountID, peerID string) (bool, error)
}
@@ -21,15 +19,11 @@ func NewControllerMock() *ControllerMock {
return &ControllerMock{}
}
func (c *ControllerMock) SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string, accountPeers map[string]*peer.Peer) {
func (c *ControllerMock) SendUpdate(ctx context.Context, accountID string, affectedProxyID string, affectedPeerIDs []string) {
// noop
}
func (c *ControllerMock) GetProxyNetworkMaps(ctx context.Context, accountID, peerID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error) {
return make(map[string]*nbtypes.NetworkMap), nil
}
func (c *ControllerMock) GetProxyNetworkMapsAll(ctx context.Context, accountID string, accountPeers map[string]*peer.Peer) (map[string]*nbtypes.NetworkMap, error) {
func (c *ControllerMock) GetProxyNetworkMaps(ctx context.Context, accountID string) (map[string]*nbtypes.NetworkMap, error) {
return make(map[string]*nbtypes.NetworkMap), nil
}

View File

@@ -1,160 +0,0 @@
package server
import (
"hash/fnv"
"math"
"sync"
"time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
const (
reconnThreshold = 5 * time.Minute
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
)
type lfConfig struct {
reconnThreshold time.Duration
baseBlockDuration time.Duration
reconnLimitForBan int
metaChangeLimit int
}
func initCfg() *lfConfig {
return &lfConfig{
reconnThreshold: reconnThreshold,
baseBlockDuration: baseBlockDuration,
reconnLimitForBan: reconnLimitForBan,
metaChangeLimit: metaChangeLimit,
}
}
type loginFilter struct {
mu sync.RWMutex
cfg *lfConfig
logged map[string]*peerState
}
type peerState struct {
currentHash uint64
sessionCounter int
sessionStart time.Time
lastSeen time.Time
isBanned bool
banLevel int
banExpiresAt time.Time
metaChangeCounter int
metaChangeWindowStart time.Time
}
func newLoginFilter() *loginFilter {
return newLoginFilterWithCfg(initCfg())
}
func newLoginFilterWithCfg(cfg *lfConfig) *loginFilter {
return &loginFilter{
logged: make(map[string]*peerState),
cfg: cfg,
}
}
func (l *loginFilter) allowLogin(wgPubKey string, metaHash uint64) bool {
l.mu.RLock()
defer func() {
l.mu.RUnlock()
}()
state, ok := l.logged[wgPubKey]
if !ok {
return true
}
if state.isBanned && time.Now().Before(state.banExpiresAt) {
return false
}
if metaHash != state.currentHash {
if time.Now().Before(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) && state.metaChangeCounter >= l.cfg.metaChangeLimit {
return false
}
}
return true
}
func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
now := time.Now()
l.mu.Lock()
defer func() {
l.mu.Unlock()
}()
state, ok := l.logged[wgPubKey]
if !ok {
l.logged[wgPubKey] = &peerState{
currentHash: metaHash,
sessionCounter: 1,
sessionStart: now,
lastSeen: now,
metaChangeWindowStart: now,
metaChangeCounter: 1,
}
return
}
if state.isBanned && now.After(state.banExpiresAt) {
state.isBanned = false
}
if state.banLevel > 0 && now.Sub(state.lastSeen) > (2*l.cfg.baseBlockDuration) {
state.banLevel = 0
}
if metaHash != state.currentHash {
if now.After(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) {
state.metaChangeWindowStart = now
state.metaChangeCounter = 1
} else {
state.metaChangeCounter++
}
state.currentHash = metaHash
state.sessionCounter = 1
state.sessionStart = now
state.lastSeen = now
return
}
state.sessionCounter++
if state.sessionCounter > l.cfg.reconnLimitForBan && now.Sub(state.sessionStart) < l.cfg.reconnThreshold {
state.isBanned = true
state.banLevel++
backoffFactor := math.Pow(2, float64(state.banLevel-1))
duration := time.Duration(float64(l.cfg.baseBlockDuration) * backoffFactor)
state.banExpiresAt = now.Add(duration)
state.sessionCounter = 0
state.sessionStart = now
}
state.lastSeen = now
}
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
h := fnv.New64a()
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
macs := uint64(0)
for _, na := range meta.NetworkAddresses {
for _, r := range na.Mac {
macs += uint64(r)
}
}
return h.Sum64() + macs
}

View File

@@ -1,275 +0,0 @@
package server
import (
"hash/fnv"
"math"
"math/rand"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/suite"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
func testAdvancedCfg() *lfConfig {
return &lfConfig{
reconnThreshold: 50 * time.Millisecond,
baseBlockDuration: 100 * time.Millisecond,
reconnLimitForBan: 3,
metaChangeLimit: 2,
}
}
type LoginFilterTestSuite struct {
suite.Suite
filter *loginFilter
}
func (s *LoginFilterTestSuite) SetupTest() {
s.filter = newLoginFilterWithCfg(testAdvancedCfg())
}
func TestLoginFilterTestSuite(t *testing.T) {
suite.Run(t, new(LoginFilterTestSuite))
}
func (s *LoginFilterTestSuite) TestFirstLoginIsAlwaysAllowed() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.True(s.filter.allowLogin(pubKey, meta))
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(1, s.filter.logged[pubKey].sessionCounter)
}
func (s *LoginFilterTestSuite) TestFlappingSameHashTriggersBan() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
limit := s.filter.cfg.reconnLimitForBan
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.False(s.filter.allowLogin(pubKey, meta))
s.Require().Contains(s.filter.logged, pubKey)
s.True(s.filter.logged[pubKey].isBanned)
}
func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
limit := s.filter.cfg.reconnLimitForBan
baseBan := s.filter.cfg.baseBlockDuration
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.Require().Contains(s.filter.logged, pubKey)
s.True(s.filter.logged[pubKey].isBanned)
s.Equal(1, s.filter.logged[pubKey].banLevel)
firstBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
s.InDelta(baseBan, firstBanDuration, float64(time.Millisecond))
s.filter.logged[pubKey].banExpiresAt = time.Now().Add(-time.Second)
s.filter.logged[pubKey].isBanned = false
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.True(s.filter.logged[pubKey].isBanned)
s.Equal(2, s.filter.logged[pubKey].banLevel)
secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1))
s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond))
}
func (s *LoginFilterTestSuite) TestPeerIsAllowedAfterBanExpires() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.filter.logged[pubKey] = &peerState{
isBanned: true,
banExpiresAt: time.Now().Add(-(s.filter.cfg.baseBlockDuration + time.Second)),
}
s.True(s.filter.allowLogin(pubKey, meta))
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.False(s.filter.logged[pubKey].isBanned)
}
func (s *LoginFilterTestSuite) TestBanLevelResetsAfterGoodBehavior() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.filter.logged[pubKey] = &peerState{
currentHash: meta,
banLevel: 3,
lastSeen: time.Now().Add(-3 * s.filter.cfg.baseBlockDuration),
}
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(0, s.filter.logged[pubKey].banLevel)
}
func (s *LoginFilterTestSuite) TestFlappingDifferentHashesTriggersBlock() {
pubKey := "PUB_KEY_A"
limit := s.filter.cfg.metaChangeLimit
for i := range limit {
s.filter.addLogin(pubKey, uint64(i+1))
}
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(limit, s.filter.logged[pubKey].metaChangeCounter)
isAllowed := s.filter.allowLogin(pubKey, uint64(limit+1))
s.False(isAllowed, "should block new meta hash after limit is reached")
}
func (s *LoginFilterTestSuite) TestMetaChangeIsAllowedAfterWindowResets() {
pubKey := "PUB_KEY_A"
meta1 := uint64(1)
meta2 := uint64(2)
meta3 := uint64(3)
s.filter.addLogin(pubKey, meta1)
s.filter.addLogin(pubKey, meta2)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(s.filter.cfg.metaChangeLimit, s.filter.logged[pubKey].metaChangeCounter)
s.False(s.filter.allowLogin(pubKey, meta3), "should be blocked inside window")
s.filter.logged[pubKey].metaChangeWindowStart = time.Now().Add(-(s.filter.cfg.reconnThreshold + time.Second))
s.True(s.filter.allowLogin(pubKey, meta3), "should be allowed after window expires")
s.filter.addLogin(pubKey, meta3)
s.Equal(1, s.filter.logged[pubKey].metaChangeCounter, "meta change counter should reset")
}
func BenchmarkHashingMethods(b *testing.B) {
meta := nbpeer.PeerSystemMeta{
WtVersion: "1.25.1",
OSVersion: "Ubuntu 22.04.3 LTS",
KernelVersion: "5.15.0-76-generic",
Hostname: "prod-server-database-01",
SystemSerialNumber: "PC-1234567890",
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
}
pubip := "8.8.8.8"
var resultString string
var resultUint uint64
b.Run("BuilderString", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = builderString(meta, pubip)
}
})
b.Run("FnvHashToString", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = fnvHashToString(meta, pubip)
}
})
b.Run("FnvHashToUint64 - used", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultUint = metaHash(meta, pubip)
}
})
_ = resultString
_ = resultUint
}
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
h := fnv.New64a()
if len(meta.NetworkAddresses) != 0 {
for _, na := range meta.NetworkAddresses {
h.Write([]byte(na.Mac))
}
}
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
return strconv.FormatUint(h.Sum64(), 16)
}
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
mac := getMacAddress(meta.NetworkAddresses)
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
len(pubip) + len(mac) + 6
var b strings.Builder
b.Grow(estimatedSize)
b.WriteString(meta.WtVersion)
b.WriteByte('|')
b.WriteString(meta.OSVersion)
b.WriteByte('|')
b.WriteString(meta.KernelVersion)
b.WriteByte('|')
b.WriteString(meta.Hostname)
b.WriteByte('|')
b.WriteString(meta.SystemSerialNumber)
b.WriteByte('|')
b.WriteString(pubip)
return b.String()
}
func getMacAddress(nas []nbpeer.NetworkAddress) string {
if len(nas) == 0 {
return ""
}
macs := make([]string, 0, len(nas))
for _, na := range nas {
macs = append(macs, na.Mac)
}
return strings.Join(macs, "/")
}
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
filter := newLoginFilterWithCfg(testAdvancedCfg())
numKeys := 100000
pubKeys := make([]string, numKeys)
for i := range numKeys {
pubKeys[i] = "PUB_KEY_" + strconv.Itoa(i)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for pb.Next() {
key := pubKeys[r.Intn(numKeys)]
meta := r.Uint64()
if filter.allowLogin(key, meta) {
filter.addLogin(key, meta)
}
}
})
}

View File

@@ -61,7 +61,7 @@ type MockAccountManager struct {
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIPFunc func(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool, isSelected bool) (*route.Route, error)
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
@@ -95,8 +95,6 @@ type MockAccountManager struct {
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() account.ExternalCacheManager
@@ -518,9 +516,9 @@ func (am *MockAccountManager) UpdatePeerIP(ctx context.Context, accountID, userI
}
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool, isSelected bool) (*route.Route, error) {
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
if am.CreateRouteFunc != nil {
return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute, isSelected)
return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
}
@@ -631,20 +629,6 @@ func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string,
return status.Errorf(codes.Unimplemented, "method InviteUser is not implemented")
}
func (am *MockAccountManager) ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) {
if am.ApproveUserFunc != nil {
return am.ApproveUserFunc(ctx, accountID, initiatorUserID, targetUserID)
}
return nil, status.Errorf(codes.Unimplemented, "method ApproveUser is not implemented")
}
func (am *MockAccountManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error {
if am.RejectUserFunc != nil {
return am.RejectUserFunc(ctx, accountID, initiatorUserID, targetUserID)
}
return status.Errorf(codes.Unimplemented, "method RejectUser is not implemented")
}
// GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface
func (am *MockAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
if am.GetNameServerGroupFunc != nil {
@@ -993,10 +977,3 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth n
}
return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented")
}
func (am *MockAccountManager) AllowSync(key string, hash uint64) bool {
if am.AllowSyncFunc != nil {
return am.AllowSyncFunc(key, hash)
}
return true
}

View File

@@ -37,6 +37,9 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
// CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -70,11 +73,11 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
return err
}
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return transaction.SaveNameServerGroup(ctx, newNSGroup)
})
if err != nil {
return nil, err
@@ -91,6 +94,9 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
// SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if nsGroupToSave == nil {
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
}
@@ -121,11 +127,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err
}
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return transaction.SaveNameServerGroup(ctx, nsGroupToSave)
})
if err != nil {
return err
@@ -142,6 +148,9 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
// DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -164,11 +173,11 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return err
}
if err = transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID)
})
if err != nil {
return err

View File

@@ -70,6 +70,9 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
network.ID = xid.New().String()
unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID)
defer unlock()
err = m.store.SaveNetwork(ctx, network)
if err != nil {
return nil, fmt.Errorf("failed to save network: %w", err)
@@ -101,6 +104,9 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
return nil, status.NewPermissionDeniedError()
}
unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID)
defer unlock()
_, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
if err != nil {
return nil, fmt.Errorf("failed to get network: %w", err)
@@ -125,6 +131,9 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
return fmt.Errorf("failed to get network: %w", err)
}
unlock := m.store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var eventsToStore []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
@@ -158,15 +167,15 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
return fmt.Errorf("failed to delete network: %w", err)
}
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, networkID, accountID, activity.NetworkDeleted, network.EventMeta())
})
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, networkID, accountID, activity.NetworkDeleted, network.EventMeta())
})
return nil
})
if err != nil {

View File

@@ -108,6 +108,9 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
return nil, fmt.Errorf("failed to create new network resource: %w", err)
}
unlock := m.store.AcquireWriteLockByUID(ctx, resource.AccountID)
defer unlock()
var eventsToStore []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
@@ -201,6 +204,9 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
resource.Domain = domain
resource.Prefix = prefix
unlock := m.store.AcquireWriteLockByUID(ctx, resource.AccountID)
defer unlock()
var eventsToStore []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
@@ -309,6 +315,9 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
return status.NewPermissionDeniedError()
}
unlock := m.store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var events []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID)

View File

@@ -88,6 +88,9 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
return nil, status.NewPermissionDeniedError()
}
unlock := m.store.AcquireWriteLockByUID(ctx, router.AccountID)
defer unlock()
var network *networkTypes.Network
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
@@ -154,6 +157,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
return nil, status.NewPermissionDeniedError()
}
unlock := m.store.AcquireWriteLockByUID(ctx, router.AccountID)
defer unlock()
var network *networkTypes.Network
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
@@ -197,6 +203,9 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
return status.NewPermissionDeniedError()
}
unlock := m.store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var event func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID)

View File

@@ -192,6 +192,9 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -458,6 +461,9 @@ func (am *DefaultAccountManager) GetPeerJobByID(ctx context.Context, accountID,
// DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -480,7 +486,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return err
}
@@ -494,6 +500,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
if err = transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
return fmt.Errorf("failed to remove peer from groups: %w", err)
}
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
@@ -543,7 +553,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
}
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, err
@@ -615,9 +625,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
if err != nil {
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: user not found")
}
if user.PendingApproval {
return nil, nil, nil, status.Errorf(status.PermissionDenied, "user pending approval cannot add peers")
}
groupsToAdd = user.AutoGroups
opEvent.InitiatorID = userID
opEvent.Activity = activity.PeerAddedByUser
@@ -728,6 +735,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
newPeer.DNSLabel = freeLabel
newPeer.IP = freeIP
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer func() {
if unlock != nil {
unlock()
}
}()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err = transaction.AddPeerToAccount(ctx, newPeer)
if err != nil {
@@ -779,10 +793,14 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil
})
if err == nil {
unlock()
unlock = nil
break
}
if isUniqueConstraintError(err) {
unlock()
unlock = nil
log.WithContext(ctx).WithFields(log.Fields{"dns_label": freeLabel, "ip": freeIP}).Tracef("Failed to add peer in attempt %d, retrying: %v", attempt, err)
continue
}
@@ -941,6 +959,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
}
}
unlockAccount := am.Store.AcquireReadLockByUID(ctx, accountID)
defer unlockAccount()
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, login.WireGuardPubKey)
defer func() {
if unlockPeer != nil {
unlockPeer()
}
}()
var peer *nbpeer.Peer
var updateRemotePeers bool
var isRequiresApproval bool
@@ -1021,6 +1048,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err
}
unlockPeer()
unlockPeer = nil
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
am.BufferUpdateAccountPeers(ctx, accountID)
}
@@ -1152,7 +1182,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, nil, nil, err
@@ -1325,7 +1355,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return
@@ -1464,7 +1494,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
return
}
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId, peerId, account.Peers)
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return
@@ -1652,7 +1682,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
}
dnsDomain := am.GetDNSDomain(settings)
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, err
}

View File

@@ -24,7 +24,7 @@ type Peer struct {
// Meta is a Peer system meta data
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
// Name is peer's name (machine name)
Name string `gorm:"index"`
Name string
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
DNSLabel string // uniqueness index per accountID (check migrations)

View File

@@ -26,7 +26,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
@@ -990,14 +989,19 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := bc.minMsPerOpLocal
maxExpected := bc.maxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = bc.minMsPerOpCICD
maxExpected = bc.maxMsPerOpCICD
testing_tools.EvaluateBenchmarkResults(b, bc.name, time.Since(start), "login", "newPeer")
}
if msPerOp > maxExpected {
b.Logf("Benchmark %s: too slow (%.2f ms/op, max %.2f ms/op)", bc.name, msPerOp, maxExpected)
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
}
if msPerOp > (maxExpected * 1.1) {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
}
})
}
@@ -1605,6 +1609,7 @@ func Test_LoginPeer(t *testing.T) {
testCases := []struct {
name string
setupKey string
wireGuardPubKey string
expectExtraDNSLabelsMismatch bool
extraDNSLabels []string
expectLoginError bool
@@ -1968,7 +1973,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
_, err := manager.CreateRoute(
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
route.Groups, []string{}, true, userID, route.KeepRoute, route.SkipAutoApply,
route.Groups, []string{}, true, userID, route.KeepRoute,
)
require.NoError(t, err)
@@ -2383,186 +2388,3 @@ func TestBufferUpdateAccountPeers(t *testing.T) {
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
}
func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
// Create account
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Create user pending approval
pendingUser := types.NewRegularUser("pending-user")
pendingUser.AccountID = account.Id
pendingUser.Blocked = true
pendingUser.PendingApproval = true
err = manager.Store.SaveUser(context.Background(), pendingUser)
require.NoError(t, err)
// Try to add peer with pending approval user
key, err := wgtypes.GenerateKey()
require.NoError(t, err)
peer := &nbpeer.Peer{
Key: key.PublicKey().String(),
Name: "test-peer",
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-peer",
OS: "linux",
},
}
_, _, _, err = manager.AddPeer(context.Background(), "", pendingUser.Id, peer)
require.Error(t, err)
assert.Contains(t, err.Error(), "user pending approval cannot add peers")
}
func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
// Create account
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Create regular user (not pending approval)
regularUser := types.NewRegularUser("regular-user")
regularUser.AccountID = account.Id
err = manager.Store.SaveUser(context.Background(), regularUser)
require.NoError(t, err)
// Try to add peer with regular user
key, err := wgtypes.GenerateKey()
require.NoError(t, err)
peer := &nbpeer.Peer{
Key: key.PublicKey().String(),
Name: "test-peer",
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-peer",
OS: "linux",
},
}
_, _, _, err = manager.AddPeer(context.Background(), "", regularUser.Id, peer)
require.NoError(t, err, "Regular user should be able to add peers")
}
func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
// Create account
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Create user pending approval
pendingUser := types.NewRegularUser("pending-user")
pendingUser.AccountID = account.Id
pendingUser.Blocked = true
pendingUser.PendingApproval = true
err = manager.Store.SaveUser(context.Background(), pendingUser)
require.NoError(t, err)
// Create a peer using AddPeer method for the pending user (simulate existing peer)
key, err := wgtypes.GenerateKey()
require.NoError(t, err)
// Set the user to not be pending initially so peer can be added
pendingUser.Blocked = false
pendingUser.PendingApproval = false
err = manager.Store.SaveUser(context.Background(), pendingUser)
require.NoError(t, err)
// Add peer using regular flow
newPeer := &nbpeer.Peer{
Key: key.PublicKey().String(),
Name: "test-peer",
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-peer",
OS: "linux",
WtVersion: "0.28.0",
},
}
existingPeer, _, _, err := manager.AddPeer(context.Background(), "", pendingUser.Id, newPeer)
require.NoError(t, err)
// Now set the user back to pending approval after peer was created
pendingUser.Blocked = true
pendingUser.PendingApproval = true
err = manager.Store.SaveUser(context.Background(), pendingUser)
require.NoError(t, err)
// Try to login with pending approval user
login := types.PeerLogin{
WireGuardPubKey: existingPeer.Key,
UserID: pendingUser.Id,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-peer",
OS: "linux",
},
}
_, _, _, err = manager.LoginPeer(context.Background(), login)
require.Error(t, err)
e, ok := status.FromError(err)
require.True(t, ok, "error is not a gRPC status error")
assert.Equal(t, status.PermissionDenied, e.Type(), "expected PermissionDenied error code")
}
func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
// Create account
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Create regular user (not pending approval)
regularUser := types.NewRegularUser("regular-user")
regularUser.AccountID = account.Id
err = manager.Store.SaveUser(context.Background(), regularUser)
require.NoError(t, err)
// Add peer using regular flow for the regular user
key, err := wgtypes.GenerateKey()
require.NoError(t, err)
newPeer := &nbpeer.Peer{
Key: key.PublicKey().String(),
Name: "test-peer",
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-peer",
OS: "linux",
WtVersion: "0.28.0",
},
}
existingPeer, _, _, err := manager.AddPeer(context.Background(), "", regularUser.Id, newPeer)
require.NoError(t, err)
// Try to login with regular user
login := types.PeerLogin{
WireGuardPubKey: existingPeer.Key,
UserID: regularUser.Id,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-peer",
OS: "linux",
},
}
_, _, _, err = manager.LoginPeer(context.Background(), login)
require.NoError(t, err, "Regular user should be able to login peers")
}

View File

@@ -18,7 +18,6 @@ type Manager interface {
GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error)
GetPeerAccountID(ctx context.Context, peerID string) (string, error)
GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error)
GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error)
}
type managerImpl struct {
@@ -62,7 +61,3 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string)
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
}
func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
}

View File

@@ -79,18 +79,3 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
}
// GetPeersByGroupIDs mocks base method.
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeersByGroupIDs", ctx, accountID, groupsIDs)
ret0, _ := ret[0].([]*peer.Peer)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeersByGroupIDs indicates an expected call of GetPeersByGroupIDs.
func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs)
}

View File

@@ -54,14 +54,10 @@ func (m *managerImpl) ValidateUserPermissions(
return false, status.NewUserNotFoundError(userID)
}
if user.IsBlocked() && !user.PendingApproval {
if user.IsBlocked() {
return false, status.NewUserBlockedError()
}
if user.IsBlocked() && user.PendingApproval {
return false, status.NewUserPendingApprovalError()
}
if err := m.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
return false, err
}

View File

@@ -32,6 +32,9 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
// SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
operation := operations.Create
if !create {
operation = operations.Update
@@ -58,17 +61,17 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return err
}
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
saveFunc := transaction.CreatePolicy
if isUpdate {
action = activity.PolicyUpdated
saveFunc = transaction.SavePolicy
}
if err = saveFunc(ctx, policy); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return saveFunc(ctx, policy)
})
if err != nil {
return nil, err
@@ -85,6 +88,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
// DeletePolicy from the store
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -107,11 +113,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
return err
}
if err = transaction.DeletePolicy(ctx, accountID, policyID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return transaction.DeletePolicy(ctx, accountID, policyID)
})
if err != nil {
return err
@@ -167,22 +173,10 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
// validatePolicy validates the policy and its rules.
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
if policy.ID != "" {
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil {
return err
}
// TODO: Refactor to support multiple rules per policy
existingRuleIDs := make(map[string]bool)
for _, rule := range existingPolicy.Rules {
existingRuleIDs[rule.ID] = true
}
for _, rule := range policy.Rules {
if rule.ID != "" && !existingRuleIDs[rule.ID] {
return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
}
}
} else {
policy.ID = xid.New().String()
policy.AccountID = accountID

View File

@@ -32,6 +32,9 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
// SavePostureChecks saves a posture check.
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
operation := operations.Create
if !create {
operation = operations.Update
@@ -59,19 +62,15 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
return err
}
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
action = activity.PostureCheckUpdated
}
postureChecks.AccountID = accountID
if err = transaction.SavePostureChecks(ctx, postureChecks); err != nil {
return err
}
if isUpdate {
return transaction.IncrementNetworkSerial(ctx, accountID)
}
return nil
return transaction.SavePostureChecks(ctx, postureChecks)
})
if err != nil {
return nil, err
@@ -88,6 +87,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
// DeletePostureChecks deletes a posture check by ID.
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -108,11 +110,11 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
return err
}
if err = transaction.DeletePostureChecks(ctx, accountID, postureChecksID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return transaction.DeletePostureChecks(ctx, accountID, postureChecksID)
})
if err != nil {
return err

View File

@@ -134,7 +134,10 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string {
}
// CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool, skipAutoApply bool) (*route.Route, error) {
func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -167,7 +170,6 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
Enabled: enabled,
Groups: groups,
AccessControlGroups: accessControlGroupIDs,
SkipAutoApply: skipAutoApply,
}
if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil {
@@ -179,11 +181,11 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return err
}
if err = transaction.SaveRoute(ctx, newRoute); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return transaction.SaveRoute(ctx, newRoute)
})
if err != nil {
return nil, err
@@ -200,6 +202,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
// SaveRoute saves route
func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -233,11 +238,11 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
}
routeToSave.AccountID = accountID
if err = transaction.SaveRoute(ctx, routeToSave); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return transaction.SaveRoute(ctx, routeToSave)
})
if err != nil {
return err
@@ -254,6 +259,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -276,11 +284,11 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
return err
}
if err = transaction.DeleteRoute(ctx, accountID, string(routeID)); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
return transaction.DeleteRoute(ctx, accountID, string(routeID))
})
if err != nil {
return fmt.Errorf("failed to delete route %s: %w", routeID, err)
@@ -374,16 +382,15 @@ func validateRouteGroups(ctx context.Context, transaction store.Store, accountID
func toProtocolRoute(route *route.Route) *proto.Route {
return &proto.Route{
ID: string(route.ID),
NetID: string(route.NetID),
Network: route.Network.String(),
Domains: route.Domains.ToPunycodeList(),
NetworkType: int64(route.NetworkType),
Peer: route.Peer,
Metric: int64(route.Metric),
Masquerade: route.Masquerade,
KeepRoute: route.KeepRoute,
SkipAutoApply: route.SkipAutoApply,
ID: string(route.ID),
NetID: string(route.NetID),
Network: route.Network.String(),
Domains: route.Domains.ToPunycodeList(),
NetworkType: int64(route.NetworkType),
Peer: route.Peer,
Metric: int64(route.Metric),
Masquerade: route.Masquerade,
KeepRoute: route.KeepRoute,
}
}

View File

@@ -69,7 +69,6 @@ func TestCreateRoute(t *testing.T) {
enabled bool
groups []string
accessControlGroups []string
skipAutoApply bool
}
testCases := []struct {
@@ -445,13 +444,13 @@ func TestCreateRoute(t *testing.T) {
if testCase.createInitRoute {
groupAll, errInit := account.GetGroupAll()
require.NoError(t, errInit)
_, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false, true)
_, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
require.NoError(t, errInit)
_, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false, true)
_, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false)
require.NoError(t, errInit)
}
outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute, testCase.inputArgs.skipAutoApply)
outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
testCase.errFunc(t, err)
@@ -1085,7 +1084,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute, baseRoute.SkipAutoApply)
newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
require.NoError(t, err)
require.Equal(t, newRoute.Enabled, true)
@@ -1177,7 +1176,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute, baseRoute.SkipAutoApply)
createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute)
require.NoError(t, err)
noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -2005,7 +2004,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
_, err := manager.CreateRoute(
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
route.Groups, []string{}, true, userID, route.KeepRoute, route.SkipAutoApply,
route.Groups, []string{}, true, userID, route.KeepRoute,
)
require.NoError(t, err)
@@ -2041,7 +2040,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
_, err := manager.CreateRoute(
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer,
route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric,
route.Groups, []string{}, true, userID, route.KeepRoute, route.SkipAutoApply,
route.Groups, []string{}, true, userID, route.KeepRoute,
)
require.NoError(t, err)
@@ -2077,7 +2076,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
newRoute, err := manager.CreateRoute(
context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer,
baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric,
baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute, !baseRoute.SkipAutoApply,
baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute,
)
require.NoError(t, err)
baseRoute = *newRoute
@@ -2143,7 +2142,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
_, err := manager.CreateRoute(
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, !newRoute.SkipAutoApply,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
)
require.NoError(t, err)
@@ -2183,7 +2182,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
_, err := manager.CreateRoute(
context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer,
newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, !newRoute.SkipAutoApply,
newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute,
)
require.NoError(t, err)

View File

@@ -55,6 +55,8 @@ type SetupKeyUpdateOperation struct {
// and adds it to the specified account. A list of autoGroups IDs can be empty.
func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Create)
if err != nil {
@@ -105,6 +107,9 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
}
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.SetupKeys, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)

View File

@@ -52,6 +52,7 @@ const (
// SqlStore represents an account storage backed by a Sql DB persisted to disk
type SqlStore struct {
db *gorm.DB
resourceLocks sync.Map
globalAccountLock sync.Mutex
metrics telemetry.AppMetrics
installationPK int
@@ -218,6 +219,44 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
return unlock
}
// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID)
startWait := time.Now()
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.Lock()
log.WithContext(ctx).Tracef("waiting to acquire write lock for ID %s in %v", uniqueID, time.Since(startWait))
startHold := time.Now()
unlock = func() {
mtx.Unlock()
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(startHold))
}
return unlock
}
// AcquireReadLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID)
startWait := time.Now()
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.RLock()
log.WithContext(ctx).Tracef("waiting to acquire read lock for ID %s in %v", uniqueID, time.Since(startWait))
startHold := time.Now()
unlock = func() {
mtx.RUnlock()
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(startHold))
}
return unlock
}
// Deprecated: Full account operations are no longer supported
func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error {
start := time.Now()
@@ -989,7 +1028,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) {
var account types.Account
result := s.db.Select("id").Order("created_at desc").Limit(1).Find(&account)
result := s.db.WithContext(ctx).Select("id").Order("created_at desc").Limit(1).Find(&account)
if result.Error != nil {
return "", status.NewGetAccountFromStoreError(result.Error)
}
@@ -1474,7 +1513,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI
PeerID: peerID,
}
err := s.db.Clauses(clause.OnConflict{
err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
DoNothing: true,
}).Create(peer).Error
@@ -1489,7 +1528,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI
// RemovePeerFromGroup removes a peer from a group
func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error {
err := s.db.
err := s.db.WithContext(ctx).
Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error
if err != nil {
@@ -1502,7 +1541,7 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group
// RemovePeerFromAllGroups removes a peer from all groups
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
err := s.db.
err := s.db.WithContext(ctx).
Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error
if err != nil {
@@ -2090,7 +2129,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error {
}
func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) error {
return s.db.Transaction(func(tx *gorm.DB) error {
return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil {
return fmt.Errorf("delete policy rules: %w", err)
}
@@ -2781,7 +2820,7 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength
}
func (s *SqlStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) {
tx := s.db
tx := s.db.WithContext(ctx)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -2922,22 +2961,3 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i
}
return nil
}
func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) {
if len(groupIDs) == 0 {
return []*nbpeer.Peer{}, nil
}
var peers []*nbpeer.Peer
peerIDsSubquery := s.db.Model(&types.GroupPeer{}).
Select("DISTINCT peer_id").
Where("account_id = ? AND group_id IN ?", accountID, groupIDs)
result := s.db.Where("id IN (?)", peerIDsSubquery).Find(&peers)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get peers by group IDs: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get peers by group IDs")
}
return peers, nil
}

View File

@@ -3607,113 +3607,3 @@ func intToIPv4(n uint32) net.IP {
binary.BigEndian.PutUint32(ip, n)
return ip
}
func TestSqlStore_GetPeersByGroupIDs(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group1ID := "test-group-1"
group2ID := "test-group-2"
emptyGroupID := "empty-group"
peer1 := "cfefqs706sqkneg59g4g"
peer2 := "cfeg6sf06sqkneg59g50"
tests := []struct {
name string
groupIDs []string
expectedPeers []string
expectedCount int
}{
{
name: "retrieve peers from single group with multiple peers",
groupIDs: []string{group1ID},
expectedPeers: []string{peer1, peer2},
expectedCount: 2,
},
{
name: "retrieve peers from single group with one peer",
groupIDs: []string{group2ID},
expectedPeers: []string{peer1},
expectedCount: 1,
},
{
name: "retrieve peers from multiple groups (with overlap)",
groupIDs: []string{group1ID, group2ID},
expectedPeers: []string{peer1, peer2}, // should deduplicate
expectedCount: 2,
},
{
name: "retrieve peers from existing 'All' group",
groupIDs: []string{"cfefqs706sqkneg59g3g"}, // All group from test data
expectedPeers: []string{peer1, peer2},
expectedCount: 2,
},
{
name: "retrieve peers from empty group",
groupIDs: []string{emptyGroupID},
expectedPeers: []string{},
expectedCount: 0,
},
{
name: "retrieve peers from non-existing group",
groupIDs: []string{"non-existing-group"},
expectedPeers: []string{},
expectedCount: 0,
},
{
name: "empty group IDs list",
groupIDs: []string{},
expectedPeers: []string{},
expectedCount: 0,
},
{
name: "mix of existing and non-existing groups",
groupIDs: []string{group1ID, "non-existing-group"},
expectedPeers: []string{peer1, peer2},
expectedCount: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
groups := []*types.Group{
{
ID: group1ID,
AccountID: accountID,
},
{
ID: group2ID,
AccountID: accountID,
},
}
require.NoError(t, store.CreateGroups(ctx, accountID, groups))
require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group1ID))
require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer2, group1ID))
require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group2ID))
peers, err := store.GetPeersByGroupIDs(ctx, accountID, tt.groupIDs)
require.NoError(t, err)
require.Len(t, peers, tt.expectedCount)
if tt.expectedCount > 0 {
actualPeerIDs := make([]string, len(peers))
for i, peer := range peers {
actualPeerIDs[i] = peer.ID
}
assert.ElementsMatch(t, tt.expectedPeers, actualPeerIDs)
// Verify all returned peers belong to the correct account
for _, peer := range peers {
assert.Equal(t, accountID, peer.AccountID)
}
}
})
}
}

View File

@@ -136,7 +136,6 @@ type Store interface {
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error)
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
@@ -169,6 +168,10 @@ type Store interface {
GetInstallationID() string
SaveInstallationID(ctx context.Context, ID string) error
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock(ctx context.Context) func()

View File

@@ -4,28 +4,20 @@ import (
"context"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
const AccountIDLabel = "account_id"
const HighLatencyThreshold = time.Second * 7
// GRPCMetrics are gRPC server metrics
type GRPCMetrics struct {
meter metric.Meter
syncRequestsCounter metric.Int64Counter
syncRequestsBlockedCounter metric.Int64Counter
syncRequestHighLatencyCounter metric.Int64Counter
loginRequestsCounter metric.Int64Counter
loginRequestsBlockedCounter metric.Int64Counter
loginRequestHighLatencyCounter metric.Int64Counter
getKeyRequestsCounter metric.Int64Counter
activeStreamsGauge metric.Int64ObservableGauge
syncRequestDuration metric.Int64Histogram
loginRequestDuration metric.Int64Histogram
channelQueueLength metric.Int64Histogram
ctx context.Context
meter metric.Meter
syncRequestsCounter metric.Int64Counter
loginRequestsCounter metric.Int64Counter
getKeyRequestsCounter metric.Int64Counter
activeStreamsGauge metric.Int64ObservableGauge
syncRequestDuration metric.Int64Histogram
loginRequestDuration metric.Int64Histogram
channelQueueLength metric.Int64Histogram
ctx context.Context
}
// NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server
@@ -38,22 +30,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err
}
syncRequestsBlockedCounter, err := meter.Int64Counter("management.grpc.sync.request.blocked.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of sync gRPC requests from blocked peers"),
)
if err != nil {
return nil, err
}
syncRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.sync.request.high.latency.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of sync gRPC requests from the peers that took longer than the threshold to establish a connection and receive network map updates (update channel)"),
)
if err != nil {
return nil, err
}
loginRequestsCounter, err := meter.Int64Counter("management.grpc.login.request.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
@@ -62,22 +38,6 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err
}
loginRequestsBlockedCounter, err := meter.Int64Counter("management.grpc.login.request.blocked.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of login gRPC requests from blocked peers"),
)
if err != nil {
return nil, err
}
loginRequestHighLatencyCounter, err := meter.Int64Counter("management.grpc.login.request.high.latency.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of login gRPC requests from the peers that took longer than the threshold to authenticate and receive initial configuration and relay credentials"),
)
if err != nil {
return nil, err
}
getKeyRequestsCounter, err := meter.Int64Counter("management.grpc.key.request.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of key gRPC requests from the peers to get the server's public WireGuard key"),
@@ -123,19 +83,15 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
}
return &GRPCMetrics{
meter: meter,
syncRequestsCounter: syncRequestsCounter,
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
syncRequestHighLatencyCounter: syncRequestHighLatencyCounter,
loginRequestsCounter: loginRequestsCounter,
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
getKeyRequestsCounter: getKeyRequestsCounter,
activeStreamsGauge: activeStreamsGauge,
syncRequestDuration: syncRequestDuration,
loginRequestDuration: loginRequestDuration,
channelQueueLength: channelQueue,
ctx: ctx,
meter: meter,
syncRequestsCounter: syncRequestsCounter,
loginRequestsCounter: loginRequestsCounter,
getKeyRequestsCounter: getKeyRequestsCounter,
activeStreamsGauge: activeStreamsGauge,
syncRequestDuration: syncRequestDuration,
loginRequestDuration: loginRequestDuration,
channelQueueLength: channelQueue,
ctx: ctx,
}, err
}
@@ -144,11 +100,6 @@ func (grpcMetrics *GRPCMetrics) CountSyncRequest() {
grpcMetrics.syncRequestsCounter.Add(grpcMetrics.ctx, 1)
}
// CountSyncRequestBlocked counts the number of gRPC sync requests from blocked peers
func (grpcMetrics *GRPCMetrics) CountSyncRequestBlocked() {
grpcMetrics.syncRequestsBlockedCounter.Add(grpcMetrics.ctx, 1)
}
// CountGetKeyRequest counts the number of gRPC get server key requests coming to the gRPC API
func (grpcMetrics *GRPCMetrics) CountGetKeyRequest() {
grpcMetrics.getKeyRequestsCounter.Add(grpcMetrics.ctx, 1)
@@ -159,25 +110,14 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequest() {
grpcMetrics.loginRequestsCounter.Add(grpcMetrics.ctx, 1)
}
// CountLoginRequestBlocked counts the number of gRPC login requests from blocked peers
func (grpcMetrics *GRPCMetrics) CountLoginRequestBlocked() {
grpcMetrics.loginRequestsBlockedCounter.Add(grpcMetrics.ctx, 1)
}
// CountLoginRequestDuration counts the duration of the login gRPC requests
func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) {
func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration) {
grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
if duration > HighLatencyThreshold {
grpcMetrics.loginRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
}
}
// CountSyncRequestDuration counts the duration of the sync gRPC requests
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration) {
grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
if duration > HighLatencyThreshold {
grpcMetrics.syncRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
}
}
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.

View File

@@ -300,12 +300,9 @@ func (a *Account) GetPeerNetworkMap(
if dnsManagementStatus {
var zones []nbdns.CustomZone
if peersCustomZone.Domain != "" {
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect)
zones = append(zones, nbdns.CustomZone{
Domain: peersCustomZone.Domain,
Records: records,
})
zones = append(zones, peersCustomZone)
}
dnsUpdate.CustomZones = zones
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
@@ -1654,24 +1651,3 @@ func peerSupportsPortRanges(peerVer string) bool {
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
return err == nil && meetMinVer
}
// filterZoneRecordsForPeers filters DNS records to only include peers to connect.
func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord {
filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records))
peerIPs := make(map[string]struct{})
// Add peer's own IP to include its own DNS records
peerIPs[peer.IP.String()] = struct{}{}
for _, peerToConnect := range peersToConnect {
peerIPs[peerToConnect.IP.String()] = struct{}{}
}
for _, record := range customZone.Records {
if _, exists := peerIPs[record.RData]; exists {
filteredRecords = append(filteredRecords, record)
}
}
return filteredRecords
}

View File

@@ -2,17 +2,14 @@ package types
import (
"context"
"fmt"
"net"
"net/netip"
"slices"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -838,109 +835,3 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) {
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
}
func Test_FilterZoneRecordsForPeers(t *testing.T) {
tests := []struct {
name string
peer *nbpeer.Peer
customZone nbdns.CustomZone
peersToConnect []*nbpeer.Peer
expectedRecords []nbdns.SimpleRecord
}{
{
name: "empty peers to connect",
customZone: nbdns.CustomZone{
Domain: "netbird.cloud.",
Records: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
peersToConnect: []*nbpeer.Peer{},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
{
name: "multiple peers multiple records match",
customZone: nbdns.CustomZone{
Domain: "netbird.cloud.",
Records: func() []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord
for i := 1; i <= 100; i++ {
records = append(records, nbdns.SimpleRecord{
Name: fmt.Sprintf("peer%d.netbird.cloud", i),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256),
})
}
return records
}(),
},
peersToConnect: func() []*nbpeer.Peer {
var peers []*nbpeer.Peer
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
peers = append(peers, &nbpeer.Peer{
ID: fmt.Sprintf("peer%d", i),
IP: net.ParseIP(fmt.Sprintf("10.0.%d.%d", i/256, i%256)),
})
}
return peers
}(),
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: func() []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord
for _, i := range []int{1, 5, 10, 25, 50, 75, 100} {
records = append(records, nbdns.SimpleRecord{
Name: fmt.Sprintf("peer%d.netbird.cloud", i),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256),
})
}
return records
}(),
},
{
name: "peers with multiple DNS labels",
customZone: nbdns.CustomZone{
Domain: "netbird.cloud.",
Records: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
{Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
{Name: "peer3.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"},
{Name: "peer3-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"},
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
peersToConnect: []*nbpeer.Peer{
{ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}},
{ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}},
},
peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")},
expectedRecords: []nbdns.SimpleRecord{
{Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
{Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
{Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"},
{Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect)
assert.Equal(t, len(tt.expectedRecords), len(result))
assert.ElementsMatch(t, tt.expectedRecords, result)
})
}
}

View File

@@ -12,11 +12,11 @@ import (
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
)
const (

View File

@@ -83,9 +83,6 @@ type ExtraSettings struct {
// PeerApprovalEnabled enables or disables the need for peers bo be approved by an administrator
PeerApprovalEnabled bool
// UserApprovalRequired enables or disables the need for users joining via domain matching to be approved by an administrator
UserApprovalRequired bool
// IntegratedValidator is the string enum for the integrated validator type
IntegratedValidator string
// IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations
@@ -102,7 +99,6 @@ type ExtraSettings struct {
func (e *ExtraSettings) Copy() *ExtraSettings {
return &ExtraSettings{
PeerApprovalEnabled: e.PeerApprovalEnabled,
UserApprovalRequired: e.UserApprovalRequired,
IntegratedValidatorGroups: slices.Clone(e.IntegratedValidatorGroups),
IntegratedValidator: e.IntegratedValidator,
FlowEnabled: e.FlowEnabled,

View File

@@ -64,7 +64,6 @@ type UserInfo struct {
NonDeletable bool `json:"non_deletable"`
LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"`
PendingApproval bool `json:"pending_approval"`
IntegrationReference integration_reference.IntegrationReference `json:"-"`
}
@@ -85,8 +84,6 @@ type User struct {
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool
// PendingApproval indicates whether the user requires approval before being activated
PendingApproval bool
// LastLogin is the last time the user logged in to IdP
LastLogin *time.Time
// CreatedAt records the time the user was created
@@ -144,17 +141,16 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
if userData == nil {
return &UserInfo{
ID: u.Id,
Email: "",
Name: u.ServiceUserName,
Role: string(u.Role),
AutoGroups: u.AutoGroups,
Status: string(UserStatusActive),
IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
LastLogin: u.GetLastLogin(),
Issued: u.Issued,
PendingApproval: u.PendingApproval,
ID: u.Id,
Email: "",
Name: u.ServiceUserName,
Role: string(u.Role),
AutoGroups: u.AutoGroups,
Status: string(UserStatusActive),
IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
LastLogin: u.GetLastLogin(),
Issued: u.Issued,
}, nil
}
if userData.ID != u.Id {
@@ -167,17 +163,16 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
}
return &UserInfo{
ID: u.Id,
Email: userData.Email,
Name: userData.Name,
Role: string(u.Role),
AutoGroups: autoGroups,
Status: string(userStatus),
IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
LastLogin: u.GetLastLogin(),
Issued: u.Issued,
PendingApproval: u.PendingApproval,
ID: u.Id,
Email: userData.Email,
Name: userData.Name,
Role: string(u.Role),
AutoGroups: autoGroups,
Status: string(userStatus),
IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
LastLogin: u.GetLastLogin(),
Issued: u.Issued,
}, nil
}
@@ -199,7 +194,6 @@ func (u *User) Copy() *User {
ServiceUserName: u.ServiceUserName,
PATs: pats,
Blocked: u.Blocked,
PendingApproval: u.PendingApproval,
LastLogin: u.LastLogin,
CreatedAt: u.CreatedAt,
Issued: u.Issued,

View File

@@ -26,6 +26,9 @@ import (
// createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role types.UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*types.UserInfo, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -73,6 +76,9 @@ func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, user
// inviteNewUser Invites a USer to a given account and creates reference in datastore
func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *types.UserInfo) (*types.UserInfo, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if am.idpManager == nil {
return nil, status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites")
}
@@ -221,6 +227,9 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
return status.Errorf(status.InvalidArgument, "self deletion is not allowed")
}
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil {
return err
@@ -276,6 +285,9 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if am.idpManager == nil {
return status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites")
}
@@ -316,6 +328,9 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
// CreatePAT creates a new PAT for the given user
func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if tokenName == "" {
return nil, status.Errorf(status.InvalidArgument, "token name can't be empty")
}
@@ -364,6 +379,9 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
// DeletePAT deletes a specific PAT from a user
func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -463,6 +481,9 @@ func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initia
// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*types.User{update}, addIfNotExists)
if err != nil {
return nil, err
@@ -519,46 +540,33 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
initiatorUser = result
}
var globalErr error
for _, update := range updates {
if update == nil {
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, update := range updates {
if update == nil {
return status.Errorf(status.InvalidArgument, "provided user update is nil")
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
userHadPeers, updatedUser, userPeersToExpire, userEvents, err := am.processUserUpdate(
ctx, transaction, groupsMap, accountID, initiatorUserID, initiatorUser, update, addIfNotExists, settings,
)
if err != nil {
return fmt.Errorf("failed to process update for user %s: %w", update.Id, err)
}
if userHadPeers {
updateAccountPeers = true
}
err = transaction.SaveUser(ctx, updatedUser)
if err != nil {
return fmt.Errorf("failed to save updated user %s: %w", update.Id, err)
}
usersToSave = append(usersToSave, updatedUser)
addUserEvents = append(addUserEvents, userEvents...)
peersToExpire = append(peersToExpire, userPeersToExpire...)
return nil
})
if err != nil {
log.WithContext(ctx).Errorf("failed to save user %s: %s", update.Id, err)
if len(updates) == 1 {
return nil, err
if userHadPeers {
updateAccountPeers = true
}
globalErr = errors.Join(globalErr, err)
// continue when updating multiple users
}
return transaction.SaveUsers(ctx, usersToSave)
})
if err != nil {
return nil, err
}
var updatedUsersInfo = make([]*types.UserInfo, 0, len(usersToSave))
var updatedUsersInfo = make([]*types.UserInfo, 0, len(updates))
userInfos, err := am.GetUsersFromAccount(ctx, accountID, initiatorUserID)
if err != nil {
@@ -591,7 +599,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
am.UpdateAccountPeers(ctx, accountID)
}
return updatedUsersInfo, globalErr
return updatedUsersInfo, nil
}
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
@@ -656,7 +664,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
}
transferredOwnerRole = result
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, updatedUser.AccountID, update.Id)
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthUpdate, updatedUser.AccountID, update.Id)
if err != nil {
return false, nil, nil, nil, err
}
@@ -942,11 +950,6 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key)
if peer.UserID == "" {
// we do not want to expire peers that are added via setup key
continue
}
if peer.Status.LoginExpired {
continue
}
@@ -965,7 +968,6 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID)
am.peersUpdateManager.CloseChannels(ctx, peerIDs)
am.BufferUpdateAccountPeers(ctx, accountID)
}
@@ -1213,77 +1215,3 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut
return userWithPermissions, nil
}
// ApproveUser approves a user that is pending approval
func (am *DefaultAccountManager) ApproveUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil {
return nil, err
}
if user.AccountID != accountID {
return nil, status.NewUserNotFoundError(targetUserID)
}
if !user.PendingApproval {
return nil, status.Errorf(status.InvalidArgument, "user %s is not pending approval", targetUserID)
}
user.Blocked = false
user.PendingApproval = false
err = am.Store.SaveUser(ctx, user)
if err != nil {
return nil, err
}
am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserApproved, nil)
userInfo, err := am.getUserInfo(ctx, user, accountID)
if err != nil {
return nil, err
}
return userInfo, nil
}
// RejectUser rejects a user that is pending approval by deleting them
func (am *DefaultAccountManager) RejectUser(ctx context.Context, accountID, initiatorUserID, targetUserID string) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Users, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotFoundError(targetUserID)
}
if !user.PendingApproval {
return status.Errorf(status.InvalidArgument, "user %s is not pending approval", targetUserID)
}
err = am.DeleteUser(ctx, accountID, initiatorUserID, targetUserID)
if err != nil {
return err
}
am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.UserRejected, nil)
return nil
}

View File

@@ -1746,117 +1746,3 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
return permissions
}
func TestApproveUser(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
// Create account with admin and pending approval user
account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false)
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Create admin user
adminUser := types.NewAdminUser("admin-user")
adminUser.AccountID = account.Id
err = manager.Store.SaveUser(context.Background(), adminUser)
require.NoError(t, err)
// Create user pending approval
pendingUser := types.NewRegularUser("pending-user")
pendingUser.AccountID = account.Id
pendingUser.Blocked = true
pendingUser.PendingApproval = true
err = manager.Store.SaveUser(context.Background(), pendingUser)
require.NoError(t, err)
// Test successful approval
approvedUser, err := manager.ApproveUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id)
require.NoError(t, err)
assert.False(t, approvedUser.IsBlocked)
assert.False(t, approvedUser.PendingApproval)
// Verify user is updated in store
updatedUser, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, pendingUser.Id)
require.NoError(t, err)
assert.False(t, updatedUser.Blocked)
assert.False(t, updatedUser.PendingApproval)
// Test approval of non-pending user should fail
_, err = manager.ApproveUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id)
require.Error(t, err)
assert.Contains(t, err.Error(), "not pending approval")
// Test approval by non-admin should fail
regularUser := types.NewRegularUser("regular-user")
regularUser.AccountID = account.Id
err = manager.Store.SaveUser(context.Background(), regularUser)
require.NoError(t, err)
pendingUser2 := types.NewRegularUser("pending-user-2")
pendingUser2.AccountID = account.Id
pendingUser2.Blocked = true
pendingUser2.PendingApproval = true
err = manager.Store.SaveUser(context.Background(), pendingUser2)
require.NoError(t, err)
_, err = manager.ApproveUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id)
require.Error(t, err)
}
func TestRejectUser(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
// Create account with admin and pending approval user
account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false)
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Create admin user
adminUser := types.NewAdminUser("admin-user")
adminUser.AccountID = account.Id
err = manager.Store.SaveUser(context.Background(), adminUser)
require.NoError(t, err)
// Create user pending approval
pendingUser := types.NewRegularUser("pending-user")
pendingUser.AccountID = account.Id
pendingUser.Blocked = true
pendingUser.PendingApproval = true
err = manager.Store.SaveUser(context.Background(), pendingUser)
require.NoError(t, err)
// Test successful rejection
err = manager.RejectUser(context.Background(), account.Id, adminUser.Id, pendingUser.Id)
require.NoError(t, err)
// Verify user is deleted from store
_, err = manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, pendingUser.Id)
require.Error(t, err)
// Test rejection of non-pending user should fail
regularUser := types.NewRegularUser("regular-user")
regularUser.AccountID = account.Id
err = manager.Store.SaveUser(context.Background(), regularUser)
require.NoError(t, err)
err = manager.RejectUser(context.Background(), account.Id, adminUser.Id, regularUser.Id)
require.Error(t, err)
assert.Contains(t, err.Error(), "not pending approval")
// Test rejection by non-admin should fail
pendingUser2 := types.NewRegularUser("pending-user-2")
pendingUser2.AccountID = account.Id
pendingUser2.Blocked = true
pendingUser2.PendingApproval = true
err = manager.Store.SaveUser(context.Background(), pendingUser2)
require.NoError(t, err)
err = manager.RejectUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id)
require.Error(t, err)
}