[management] move network map logic into new design (#4774)

This commit is contained in:
Pascal Fischer
2025-11-13 12:09:46 +01:00
committed by GitHub
parent c28275611b
commit cc97cffff1
62 changed files with 2568 additions and 1989 deletions

View File

@@ -11,10 +11,8 @@ import (
"reflect"
"regexp"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
cacheStore "github.com/eko/gocache/lib/v4/store"
@@ -26,6 +24,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
@@ -53,9 +52,6 @@ const (
peerSchedulerRetryInterval = 3 * time.Second
emptyUserID = "empty user ID in claims"
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
envNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
envNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
)
type userLoggedInOnce bool
@@ -71,7 +67,7 @@ type DefaultAccountManager struct {
cacheMux sync.Mutex
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
cacheLoading map[string]chan struct{}
peersUpdateManager *PeersUpdateManager
networkMapController network_map.Controller
idpManager idp.Manager
cacheManager *nbcache.AccountUserDataCache
externalCacheManager nbcache.UserDataCache
@@ -91,8 +87,7 @@ type DefaultAccountManager struct {
singleAccountMode bool
// singleAccountModeDomain is a domain to use in singleAccountMode setup
singleAccountModeDomain string
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
peerLoginExpiry Scheduler
peerInactivityExpiry Scheduler
@@ -106,19 +101,11 @@ type DefaultAccountManager struct {
permissionsManager permissions.Manager
accountUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64
loginFilter *loginFilter
disableDefaultPolicy bool
holder *types.Holder
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
}
var _ account.Manager = (*DefaultAccountManager)(nil)
func isUniqueConstraintError(err error) bool {
switch {
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
@@ -185,10 +172,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups []
func BuildManager(
ctx context.Context,
store store.Store,
peersUpdateManager *PeersUpdateManager,
networkMapController network_map.Controller,
idpManager idp.Manager,
singleAccountModeDomain string,
dnsDomain string,
eventStore activity.Store,
geo geolocation.Geolocation,
userDeleteFromIDPEnabled bool,
@@ -204,27 +190,14 @@ func BuildManager(
log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start))
}()
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(envNewNetworkMapBuilder))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", envNewNetworkMapBuilder, err)
newNetworkMapBuilder = false
}
ids := strings.Split(os.Getenv(envNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
expIDs[id] = struct{}{}
}
am := &DefaultAccountManager{
Store: store,
geo: geo,
peersUpdateManager: peersUpdateManager,
networkMapController: networkMapController,
idpManager: idpManager,
ctx: context.Background(),
cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{},
dnsDomain: dnsDomain,
eventStore: eventStore,
peerLoginExpiry: NewDefaultScheduler(),
peerInactivityExpiry: NewDefaultScheduler(),
@@ -235,15 +208,10 @@ func BuildManager(
proxyController: proxyController,
settingsManager: settingsManager,
permissionsManager: permissionsManager,
loginFilter: newLoginFilter(),
disableDefaultPolicy: disableDefaultPolicy,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
}
am.startWarmup(ctx)
am.networkMapController.StartWarmup(ctx)
accountsCounter, err := store.GetAccountsCounter(ctx)
if err != nil {
@@ -291,32 +259,6 @@ func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) {
am.ephemeralManager = em
}
func (am *DefaultAccountManager) startWarmup(ctx context.Context) {
var initialInterval int64
intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS")
interval, err := strconv.Atoi(intervalStr)
if err != nil {
initialInterval = 1
log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err)
} else {
initialInterval = int64(interval) * 10
go func() {
startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S")
startupPeriod, err := strconv.Atoi(startupPeriodStr)
if err != nil {
startupPeriod = 1
log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err)
}
time.Sleep(time.Duration(startupPeriod) * time.Second)
am.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond))
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval)
}()
}
am.updateAccountPeersBufferInterval.Store(initialInterval)
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval)
}
func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager {
return am.externalCacheManager
}
@@ -419,9 +361,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
}
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
go am.UpdateAccountPeers(ctx, accountID)
}
@@ -1504,10 +1443,6 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
}
if removedGroupAffectsPeers || newGroupsAffectsPeers {
if err := am.RecalculateNetworkMapCache(ctx, userAuth.AccountId); err != nil {
return err
}
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
}
@@ -1667,14 +1602,10 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
}
func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool {
return am.loginFilter.allowLogin(wgPubKey, metahash)
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
@@ -1682,10 +1613,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
metahash := metaHash(meta, realIP.String())
am.loginFilter.addLogin(peerPubKey, metahash)
return peer, netMap, postureChecks, nil
return peer, netMap, postureChecks, dnsfwdPort, nil
}
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
@@ -1702,41 +1630,19 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
return err
}
_, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
if err != nil {
return mapError(ctx, err)
return err
}
return nil
}
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
return am.peersUpdateManager.GetAllConnectedPeers(), nil
}
// HasConnectedChannel returns true if peers has channel in update manager, otherwise false
func (am *DefaultAccountManager) HasConnectedChannel(peerID string) bool {
return am.peersUpdateManager.HasChannel(peerID)
}
var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
func isDomainValid(domain string) bool {
return invalidDomainRegexp.MatchString(domain)
}
// GetDNSDomain returns the configured dnsDomain
func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string {
if settings == nil {
return am.dnsDomain
}
if settings.DNSDomain == "" {
return am.dnsDomain
}
return settings.DNSDomain
}
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) {
peers := []*nbpeer.Peer{}
log.WithContext(ctx).Debugf("invalidating peers %v for account %s", peerIDs, accountID)
@@ -2159,8 +2065,7 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
if err != nil {
return err
}
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
am.BufferUpdateAccountPeers(ctx, accountID)
am.networkMapController.OnPeerUpdated(peer.AccountID, peer)
}
return nil
}
@@ -2208,7 +2113,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
if err != nil {
return fmt.Errorf("get account settings: %w", err)
}
dnsDomain := am.GetDNSDomain(settings)
dnsDomain := am.networkMapController.GetDNSDomain(settings)
eventMeta := peer.EventMeta(dnsDomain)
oldIP := peer.IP.String()

View File

@@ -89,7 +89,6 @@ type Manager interface {
SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
GetDNSDomain(settings *types.Settings) string
StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error)
@@ -97,10 +96,8 @@ type Manager interface {
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error)
HasConnectedChannel(peerID string) bool
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error)
@@ -110,7 +107,7 @@ type Manager interface {
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
@@ -127,6 +124,4 @@ type Manager interface {
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
SetEphemeralManager(em ephemeral.Manager)
AllowSync(string, uint64) bool
RecalculateNetworkMapCache(ctx context.Context, accountId string) error
}

View File

@@ -0,0 +1,11 @@
package account
import (
"context"
"github.com/netbirdio/netbird/management/server/types"
)
type RequestBuffer interface {
GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error)
}

View File

@@ -22,6 +22,9 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
@@ -406,7 +409,7 @@ func TestNewAccount(t *testing.T) {
}
func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -603,7 +606,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain)
@@ -644,7 +647,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
userId := "user-id"
domain := "test.domain"
_ = newAccountWithId(context.Background(), "", userId, domain, false)
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
require.NoError(t, err, "create init user failed")
@@ -705,7 +708,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
}
func TestAccountManager_PrivateAccount(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -731,7 +734,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
}
func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -768,7 +771,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
}
func TestAccountManager_GetAccountByUserID(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -805,7 +808,7 @@ func createAccount(am *DefaultAccountManager, accountID, userID, domain string)
}
func TestAccountManager_GetAccount(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -843,7 +846,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
}
func TestAccountManager_DeleteAccount(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -924,7 +927,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
DomainCategory: types.PublicCategory,
}
am, err := createManager(b)
am, _, err := createManager(b)
if err != nil {
b.Fatal(err)
return
@@ -1016,7 +1019,7 @@ func genUsers(p string, n int) map[string]*types.User {
}
func TestAccountManager_AddPeer(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -1086,7 +1089,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
}
func TestAccountManager_AddPeerWithUserID(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -1155,7 +1158,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
}
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SaveGroup(t)
}
@@ -1164,7 +1167,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
}
func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
group := types.Group{
ID: "groupA",
@@ -1190,8 +1193,8 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
}, true)
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
defer updateManager.CloseChannel(context.Background(), peer1.ID)
wg := sync.WaitGroup{}
wg.Add(1)
@@ -1215,7 +1218,7 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
}
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePolicy(t)
}
@@ -1224,10 +1227,10 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
}
func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
manager, account, peer1, _, _ := setupNetworkMapTest(t)
manager, updateManager, account, peer1, _, _ := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
defer updateManager.CloseChannel(context.Background(), peer1.ID)
// Ensure that we do not receive an update message before the policy is deleted
time.Sleep(time.Second)
@@ -1258,7 +1261,7 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
}
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SavePolicy(t)
}
@@ -1267,7 +1270,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
}
func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
manager, updateManager, account, peer1, peer2, _ := setupNetworkMapTest(t)
group := types.Group{
AccountID: account.Id,
@@ -1280,8 +1283,8 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
return
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
defer updateManager.CloseChannel(context.Background(), peer1.ID)
wg := sync.WaitGroup{}
wg.Add(1)
@@ -1316,7 +1319,7 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
}
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePeer(t)
}
@@ -1325,7 +1328,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
}
func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
manager, updateManager, account, peer1, _, peer3 := setupNetworkMapTest(t)
group := types.Group{
ID: "groupA",
@@ -1354,8 +1357,11 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
return
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
// We need to sleep to wait for the buffer peer update
time.Sleep(300 * time.Millisecond)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
defer updateManager.CloseChannel(context.Background(), peer1.ID)
wg := sync.WaitGroup{}
wg.Add(1)
@@ -1378,7 +1384,7 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
}
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeleteGroup(t)
}
@@ -1387,10 +1393,10 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
}
func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
defer updateManager.CloseChannel(context.Background(), peer1.ID)
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
@@ -1457,7 +1463,7 @@ func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
}
func TestAccountManager_DeletePeer(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -1538,7 +1544,7 @@ func getEvent(t *testing.T, accountID string, manager nbAccount.Manager, eventTy
}
func TestGetUsersFromAccount(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -1837,7 +1843,7 @@ func hasNilField(x interface{}) error {
}
func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -1852,7 +1858,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
}
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -1908,7 +1914,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
}
func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -1951,7 +1957,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -2013,7 +2019,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
}
func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -2677,7 +2683,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
func TestAccount_SetJWTGroups(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "postgres")
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
// create a new account
@@ -2919,18 +2925,18 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
// Fatalf(format string, args ...interface{})
// }
func createManager(t testing.TB) (*DefaultAccountManager, error) {
func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) {
t.Helper()
store, err := createStore(t)
if err != nil {
return nil, err
return nil, nil, err
}
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
return nil, err
return nil, nil, err
}
ctrl := gomock.NewController(t)
@@ -2948,12 +2954,17 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store)
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
manager, err := BuildManager(ctx, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, err
return nil, nil, err
}
return manager, nil
return manager, updateManager, nil
}
func createStore(t testing.TB) (store.Store, error) {
@@ -2982,10 +2993,10 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
}
}
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
t.Helper()
manager, err := createManager(t)
manager, updateManager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -3026,10 +3037,10 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account,
peer2 := getPeer(manager, setupKey)
peer3 := getPeer(manager, setupKey)
return manager, account, peer1, peer2, peer3
return manager, updateManager, account, peer1, peer2, peer3
}
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) {
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {
case msg := <-updateMessage:
@@ -3039,7 +3050,7 @@ func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessag
}
}
func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) {
func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {
@@ -3077,7 +3088,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -3086,16 +3097,14 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
updateManager.CreateChannel(ctx, peerID)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
assert.NoError(b, err)
}
@@ -3140,7 +3149,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -3149,11 +3158,10 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
updateManager.CreateChannel(ctx, peerID)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
@@ -3210,7 +3218,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -3219,11 +3227,10 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
updateManager.CreateChannel(ctx, peerID)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
@@ -3282,7 +3289,7 @@ func TestMain(m *testing.M) {
}
func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -3328,7 +3335,7 @@ func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
}
func Test_UpdateToPrimaryAccount(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -3358,7 +3365,7 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
}
func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err)
t.Run("memory cache", func(t *testing.T) {
@@ -3408,7 +3415,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
}
func TestPropagateUserGroupMemberships(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err)
ctx := context.Background()
@@ -3525,7 +3532,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
}
func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
@@ -3557,7 +3564,7 @@ func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
}
func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
@@ -3596,7 +3603,7 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
}
func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -3663,7 +3670,7 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
}
func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -3709,7 +3716,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
}
func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}

View File

@@ -3,54 +3,23 @@ package server
import (
"context"
"slices"
"sync"
log "github.com/sirupsen/logrus"
"golang.org/x/mod/semver"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
dnsForwarderPort = nbdns.ForwarderServerPort
oldForwarderPort = nbdns.ForwarderClientPort
)
const dnsForwarderPortMinVersion = "v0.59.0"
// DNSConfigCache is a thread-safe cache for DNS configuration components
type DNSConfigCache struct {
NameServerGroups sync.Map
}
// GetNameServerGroup retrieves a cached name server group
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
if c == nil {
return nil, false
}
if value, ok := c.NameServerGroups.Load(key); ok {
return value.(*proto.NameServerGroup), true
}
return nil, false
}
// SetNameServerGroup stores a name server group in the cache
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
if c == nil {
return
}
c.NameServerGroups.Store(key, value)
}
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
@@ -117,9 +86,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -194,99 +160,3 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
return validateGroups(settings.DisabledManagementGroups, groups)
}
// computeForwarderPort checks if all peers in the account have updated to a specific version or newer.
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
if len(peers) == 0 {
return int64(oldForwarderPort)
}
reqVer := semver.Canonical(requiredVersion)
// Check if all peers have the required version or newer
for _, peer := range peers {
// Development version is always supported
if peer.Meta.WtVersion == "development" {
continue
}
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
if peerVersion == "" {
// If any peer doesn't have version info, return 0
return int64(oldForwarderPort)
}
// Compare versions
if semver.Compare(peerVersion, reqVer) < 0 {
return int64(oldForwarderPort)
}
}
// All peers have the required version or newer
return int64(dnsForwarderPort)
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache, forwardPort int64) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
ForwarderPort: forwardPort,
}
for _, zone := range update.CustomZones {
protoZone := convertToProtoCustomZone(zone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
for _, nsGroup := range update.NameServerGroups {
cacheKey := nsGroup.ID
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
} else {
protoGroup := convertToProtoNameServerGroup(nsGroup)
cache.SetNameServerGroup(cacheKey, protoGroup)
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
}
return protoUpdate
}
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
return protoZone
}
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
})
}
return protoGroup
}

View File

@@ -2,9 +2,7 @@ package server
import (
"context"
"fmt"
"net/netip"
"reflect"
"testing"
"time"
@@ -12,6 +10,8 @@ import (
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -218,7 +218,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
// return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock())
return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}
func createDNSStore(t *testing.T) (store.Store, error) {
@@ -344,247 +350,8 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
return am.Store.GetAccount(context.Background(), account.Id)
}
func generateTestData(size int) nbdns.Config {
config := nbdns.Config{
ServiceEnable: true,
CustomZones: make([]nbdns.CustomZone, size),
NameServerGroups: make([]*nbdns.NameServerGroup, size),
}
for i := 0; i < size; i++ {
config.CustomZones[i] = nbdns.CustomZone{
Domain: fmt.Sprintf("domain%d.com", i),
Records: []nbdns.SimpleRecord{
{
Name: fmt.Sprintf("record%d", i),
Type: 1,
Class: "IN",
TTL: 3600,
RData: "192.168.1.1",
},
},
}
config.NameServerGroups[i] = &nbdns.NameServerGroup{
ID: fmt.Sprintf("group%d", i),
Primary: i == 0,
Domains: []string{fmt.Sprintf("domain%d.com", i)},
SearchDomainsEnabled: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
Port: 53,
NSType: 1,
},
},
}
}
return config
}
func BenchmarkToProtocolDNSConfig(b *testing.B) {
sizes := []int{10, 100, 1000}
for _, size := range sizes {
testData := generateTestData(size)
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
cache := &DNSConfigCache{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
}
})
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &DNSConfigCache{}
toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
}
})
}
}
func TestToProtocolDNSConfigWithCache(t *testing.T) {
var cache DNSConfigCache
// Create two different configs
config1 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.com",
Records: []nbdns.SimpleRecord{
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group1",
Name: "Group 1",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
},
},
},
}
config2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.org",
Records: []nbdns.SimpleRecord{
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group2",
Name: "Group 2",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
},
},
},
}
// First run with config1
result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
// Second run with config2
result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort))
// Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
}
// Verify that result2 is different from result1 and result3
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
t.Errorf("Results should be different for different inputs")
}
if _, exists := cache.GetNameServerGroup("group1"); !exists {
t.Errorf("Cache should contain name server group 'group1'")
}
if _, exists := cache.GetNameServerGroup("group2"); !exists {
t.Errorf("Cache should contain name server group 'group2'")
}
}
func TestComputeForwarderPort(t *testing.T) {
// Test with empty peers list
peers := []*nbpeer.Peer{}
result := computeForwarderPort(peers, "v0.59.0")
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result)
}
// Test with peers that have old versions
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.57.0",
},
},
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.26.0",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result)
}
// Test with peers that have new versions
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.59.0",
},
},
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.59.0",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(dnsForwarderPort) {
t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result)
}
// Test with peers that have mixed versions
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.59.0",
},
},
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.57.0",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result)
}
// Test with peers that have empty version
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result)
}
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "development",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result == int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result)
}
// Test with peers that have unknown version string
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "unknown",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result)
}
}
func TestDNSAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{
{
@@ -600,9 +367,9 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
})
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updateManager.CloseChannel(context.Background(), peer1.ID)
})
// Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates

View File

@@ -28,7 +28,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac
}
func TestDefaultAccountManager_GetEvents(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
return
}

View File

@@ -114,9 +114,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -185,9 +182,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -256,9 +250,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -327,9 +318,6 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -376,7 +364,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
return nil
}
dnsDomain := am.GetDNSDomain(settings)
dnsDomain := am.networkMapController.GetDNSDomain(settings)
for _, peerID := range addedPeers {
peer, ok := peers[peerID]
@@ -493,9 +481,6 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -534,9 +519,6 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -565,9 +547,6 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -606,9 +585,6 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}

View File

@@ -37,7 +37,7 @@ const (
)
func TestDefaultAccountManager_CreateGroup(t *testing.T) {
am, err := createManager(t)
am, _, err := createManager(t)
if err != nil {
t.Error("failed to create account manager")
}
@@ -74,7 +74,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
}
func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
am, err := createManager(t)
am, _, err := createManager(t)
if err != nil {
t.Fatalf("failed to create account manager: %s", err)
}
@@ -156,7 +156,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
}
func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
am, err := createManager(t)
am, _, err := createManager(t)
assert.NoError(t, err, "Failed to create account manager")
manager, account, err := initTestGroupAccount(am)
@@ -408,7 +408,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
}
func TestGroupAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
g := []*types.Group{
{
@@ -442,9 +442,9 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
assert.NoError(t, err)
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updateManager.CloseChannel(context.Background(), peer1.ID)
})
// Saving a group that is not linked to any resource should not update account peers
@@ -748,7 +748,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}
func Test_AddPeerToGroup(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -805,7 +805,7 @@ func Test_AddPeerToGroup(t *testing.T) {
}
func Test_AddPeerToAll(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -862,7 +862,7 @@ func Test_AddPeerToAll(t *testing.T) {
}
func Test_AddPeerAndAddToAll(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -942,7 +942,7 @@ func uint32ToIP(n uint32) net.IP {
}
func Test_IncrementNetworkSerial(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return

File diff suppressed because it is too large Load Diff

View File

@@ -1,39 +0,0 @@
package server
import (
"github.com/netbirdio/netbird/management/server/types"
)
func (am *DefaultAccountManager) enrichAccountFromHolder(account *types.Account) {
a := am.holder.GetAccount(account.Id)
if a == nil {
am.holder.AddAccount(account)
return
}
account.NetworkMapCache = a.NetworkMapCache
if account.NetworkMapCache == nil {
return
}
account.NetworkMapCache.UpdateAccountPointer(account)
am.holder.AddAccount(account)
}
func (am *DefaultAccountManager) getAccountFromHolder(accountID string) *types.Account {
return am.holder.GetAccount(accountID)
}
func (am *DefaultAccountManager) getAccountFromHolderOrInit(accountID string) *types.Account {
a := am.holder.GetAccount(accountID)
if a != nil {
return a
}
account, err := am.holder.LoadOrStoreFunc(accountID, am.requestBuffer.GetAccountWithBackpressure)
if err != nil {
return nil
}
return account
}
func (am *DefaultAccountManager) updateAccountInHolder(account *types.Account) {
am.holder.AddAccount(account)
}

View File

@@ -13,6 +13,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/settings"
@@ -65,6 +66,7 @@ func NewAPIHandler(
permissionsManager permissions.Manager,
peersManager nbpeers.Manager,
settingsManager settings.Manager,
networkMapController network_map.Controller,
) (http.Handler, error) {
var rateLimitingConfig *middleware.RateLimiterConfig
@@ -120,7 +122,7 @@ func NewAPIHandler(
}
accounts.AddEndpoints(accountManager, settingsManager, router)
peers.AddEndpoints(accountManager, router)
peers.AddEndpoints(accountManager, router, networkMapController)
users.AddEndpoints(accountManager, router)
setup_keys.AddEndpoints(accountManager, router)
policies.AddEndpoints(accountManager, LocationManager, router)

View File

@@ -10,6 +10,7 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
@@ -23,11 +24,12 @@ import (
// Handler is a handler that returns peers of the account
type Handler struct {
accountManager account.Manager
accountManager account.Manager
networkMapController network_map.Controller
}
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
peersHandler := NewHandler(accountManager)
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) {
peersHandler := NewHandler(accountManager, networkMapController)
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
@@ -36,9 +38,10 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
}
// NewHandler creates a new peers Handler
func NewHandler(accountManager account.Manager) *Handler {
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler {
return &Handler{
accountManager: accountManager,
accountManager: accountManager,
networkMapController: networkMapController,
}
}
@@ -47,7 +50,7 @@ func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) {
if peer.Status.Connected {
// Although we have online status in store we do not yet have an updated channel so have to show it as disconnected
// This may happen after server restart when not all peers are yet connected
if !h.accountManager.HasConnectedChannel(peer.ID) {
if !h.networkMapController.IsConnected(peer.ID) {
peerToReturn.Status.Connected = false
}
}
@@ -73,7 +76,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
return
}
dnsDomain := h.accountManager.GetDNSDomain(settings)
dnsDomain := h.networkMapController.GetDNSDomain(settings)
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
@@ -139,7 +142,7 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
util.WriteError(ctx, err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain(settings)
dnsDomain := h.networkMapController.GetDNSDomain(settings)
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
if err != nil {
@@ -227,7 +230,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain(settings)
dnsDomain := h.networkMapController.GetDNSDomain(settings)
grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
@@ -317,7 +320,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
return
}
dnsDomain := h.accountManager.GetDNSDomain(account.Settings)
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)

View File

@@ -14,12 +14,14 @@ import (
"time"
"github.com/gorilla/mux"
"go.uber.org/mock/gomock"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -36,7 +38,7 @@ const (
serviceUser = "service_user"
)
func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
@@ -99,6 +101,22 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
},
}
ctrl := gomock.NewController(t)
networkMapController := network_map.NewMockController(ctrl)
networkMapController.EXPECT().
GetDNSDomain(gomock.Any()).
Return("domain").
AnyTimes()
networkMapController.EXPECT().
IsConnected(noUpdateChannelTestPeerID).
Return(false).
AnyTimes()
networkMapController.EXPECT().
IsConnected(gomock.Any()).
Return(true).
AnyTimes()
return &Handler{
accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
@@ -187,6 +205,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
return account.Settings, nil
},
},
networkMapController: networkMapController,
}
}
@@ -270,7 +289,7 @@ func TestGetPeers(t *testing.T) {
rr := httptest.NewRecorder()
p := initTestMetaData(peer, peer1)
p := initTestMetaData(t, peer, peer1)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -374,7 +393,7 @@ func TestGetAccessiblePeers(t *testing.T) {
UserID: regularUser,
}
p := initTestMetaData(peer1, peer2, peer3)
p := initTestMetaData(t, peer1, peer2, peer3)
tt := []struct {
name string
@@ -477,7 +496,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
},
}
p := initTestMetaData(testPeer)
p := initTestMetaData(t, testPeer)
tt := []struct {
name string

View File

@@ -10,6 +10,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
@@ -31,7 +35,7 @@ import (
"github.com/netbirdio/netbird/management/server/users"
)
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *network_map.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir())
if err != nil {
t.Fatalf("Failed to create test store: %v", err)
@@ -43,7 +47,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
t.Fatalf("Failed to create metrics: %v", err)
}
peersUpdateManager := server.NewPeersUpdateManager(nil)
peersUpdateManager := update_channel.NewPeersUpdateManager(nil)
updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId)
done := make(chan struct{})
if validateUpdate {
@@ -63,7 +67,11 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
ctx := context.Background()
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock())
am, err := server.BuildManager(ctx, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
@@ -83,7 +91,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
groupsManagerMock := groups.NewManagerMock()
peersManager := peers.NewManager(store, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
@@ -91,7 +99,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
return apiHandler, am, done
}
func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) {
func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {
case msg := <-updateMessage:
@@ -101,7 +109,7 @@ func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server
}
}
func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) {
func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage, expected *network_map.UpdateMessage) {
t.Helper()
select {

View File

@@ -1,160 +0,0 @@
package server
import (
"hash/fnv"
"math"
"sync"
"time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
const (
reconnThreshold = 5 * time.Minute
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
)
type lfConfig struct {
reconnThreshold time.Duration
baseBlockDuration time.Duration
reconnLimitForBan int
metaChangeLimit int
}
func initCfg() *lfConfig {
return &lfConfig{
reconnThreshold: reconnThreshold,
baseBlockDuration: baseBlockDuration,
reconnLimitForBan: reconnLimitForBan,
metaChangeLimit: metaChangeLimit,
}
}
type loginFilter struct {
mu sync.RWMutex
cfg *lfConfig
logged map[string]*peerState
}
type peerState struct {
currentHash uint64
sessionCounter int
sessionStart time.Time
lastSeen time.Time
isBanned bool
banLevel int
banExpiresAt time.Time
metaChangeCounter int
metaChangeWindowStart time.Time
}
func newLoginFilter() *loginFilter {
return newLoginFilterWithCfg(initCfg())
}
func newLoginFilterWithCfg(cfg *lfConfig) *loginFilter {
return &loginFilter{
logged: make(map[string]*peerState),
cfg: cfg,
}
}
func (l *loginFilter) allowLogin(wgPubKey string, metaHash uint64) bool {
l.mu.RLock()
defer func() {
l.mu.RUnlock()
}()
state, ok := l.logged[wgPubKey]
if !ok {
return true
}
if state.isBanned && time.Now().Before(state.banExpiresAt) {
return false
}
if metaHash != state.currentHash {
if time.Now().Before(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) && state.metaChangeCounter >= l.cfg.metaChangeLimit {
return false
}
}
return true
}
func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
now := time.Now()
l.mu.Lock()
defer func() {
l.mu.Unlock()
}()
state, ok := l.logged[wgPubKey]
if !ok {
l.logged[wgPubKey] = &peerState{
currentHash: metaHash,
sessionCounter: 1,
sessionStart: now,
lastSeen: now,
metaChangeWindowStart: now,
metaChangeCounter: 1,
}
return
}
if state.isBanned && now.After(state.banExpiresAt) {
state.isBanned = false
}
if state.banLevel > 0 && now.Sub(state.lastSeen) > (2*l.cfg.baseBlockDuration) {
state.banLevel = 0
}
if metaHash != state.currentHash {
if now.After(state.metaChangeWindowStart.Add(l.cfg.reconnThreshold)) {
state.metaChangeWindowStart = now
state.metaChangeCounter = 1
} else {
state.metaChangeCounter++
}
state.currentHash = metaHash
state.sessionCounter = 1
state.sessionStart = now
state.lastSeen = now
return
}
state.sessionCounter++
if state.sessionCounter > l.cfg.reconnLimitForBan && now.Sub(state.sessionStart) < l.cfg.reconnThreshold {
state.isBanned = true
state.banLevel++
backoffFactor := math.Pow(2, float64(state.banLevel-1))
duration := time.Duration(float64(l.cfg.baseBlockDuration) * backoffFactor)
state.banExpiresAt = now.Add(duration)
state.sessionCounter = 0
state.sessionStart = now
}
state.lastSeen = now
}
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
h := fnv.New64a()
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
macs := uint64(0)
for _, na := range meta.NetworkAddresses {
for _, r := range na.Mac {
macs += uint64(r)
}
}
return h.Sum64() + macs
}

View File

@@ -1,275 +0,0 @@
package server
import (
"hash/fnv"
"math"
"math/rand"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/suite"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
func testAdvancedCfg() *lfConfig {
return &lfConfig{
reconnThreshold: 50 * time.Millisecond,
baseBlockDuration: 100 * time.Millisecond,
reconnLimitForBan: 3,
metaChangeLimit: 2,
}
}
type LoginFilterTestSuite struct {
suite.Suite
filter *loginFilter
}
func (s *LoginFilterTestSuite) SetupTest() {
s.filter = newLoginFilterWithCfg(testAdvancedCfg())
}
func TestLoginFilterTestSuite(t *testing.T) {
suite.Run(t, new(LoginFilterTestSuite))
}
func (s *LoginFilterTestSuite) TestFirstLoginIsAlwaysAllowed() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.True(s.filter.allowLogin(pubKey, meta))
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(1, s.filter.logged[pubKey].sessionCounter)
}
func (s *LoginFilterTestSuite) TestFlappingSameHashTriggersBan() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
limit := s.filter.cfg.reconnLimitForBan
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.False(s.filter.allowLogin(pubKey, meta))
s.Require().Contains(s.filter.logged, pubKey)
s.True(s.filter.logged[pubKey].isBanned)
}
func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
limit := s.filter.cfg.reconnLimitForBan
baseBan := s.filter.cfg.baseBlockDuration
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.Require().Contains(s.filter.logged, pubKey)
s.True(s.filter.logged[pubKey].isBanned)
s.Equal(1, s.filter.logged[pubKey].banLevel)
firstBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
s.InDelta(baseBan, firstBanDuration, float64(time.Millisecond))
s.filter.logged[pubKey].banExpiresAt = time.Now().Add(-time.Second)
s.filter.logged[pubKey].isBanned = false
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.True(s.filter.logged[pubKey].isBanned)
s.Equal(2, s.filter.logged[pubKey].banLevel)
secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1))
s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond))
}
func (s *LoginFilterTestSuite) TestPeerIsAllowedAfterBanExpires() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.filter.logged[pubKey] = &peerState{
isBanned: true,
banExpiresAt: time.Now().Add(-(s.filter.cfg.baseBlockDuration + time.Second)),
}
s.True(s.filter.allowLogin(pubKey, meta))
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.False(s.filter.logged[pubKey].isBanned)
}
func (s *LoginFilterTestSuite) TestBanLevelResetsAfterGoodBehavior() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.filter.logged[pubKey] = &peerState{
currentHash: meta,
banLevel: 3,
lastSeen: time.Now().Add(-3 * s.filter.cfg.baseBlockDuration),
}
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(0, s.filter.logged[pubKey].banLevel)
}
func (s *LoginFilterTestSuite) TestFlappingDifferentHashesTriggersBlock() {
pubKey := "PUB_KEY_A"
limit := s.filter.cfg.metaChangeLimit
for i := range limit {
s.filter.addLogin(pubKey, uint64(i+1))
}
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(limit, s.filter.logged[pubKey].metaChangeCounter)
isAllowed := s.filter.allowLogin(pubKey, uint64(limit+1))
s.False(isAllowed, "should block new meta hash after limit is reached")
}
func (s *LoginFilterTestSuite) TestMetaChangeIsAllowedAfterWindowResets() {
pubKey := "PUB_KEY_A"
meta1 := uint64(1)
meta2 := uint64(2)
meta3 := uint64(3)
s.filter.addLogin(pubKey, meta1)
s.filter.addLogin(pubKey, meta2)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(s.filter.cfg.metaChangeLimit, s.filter.logged[pubKey].metaChangeCounter)
s.False(s.filter.allowLogin(pubKey, meta3), "should be blocked inside window")
s.filter.logged[pubKey].metaChangeWindowStart = time.Now().Add(-(s.filter.cfg.reconnThreshold + time.Second))
s.True(s.filter.allowLogin(pubKey, meta3), "should be allowed after window expires")
s.filter.addLogin(pubKey, meta3)
s.Equal(1, s.filter.logged[pubKey].metaChangeCounter, "meta change counter should reset")
}
func BenchmarkHashingMethods(b *testing.B) {
meta := nbpeer.PeerSystemMeta{
WtVersion: "1.25.1",
OSVersion: "Ubuntu 22.04.3 LTS",
KernelVersion: "5.15.0-76-generic",
Hostname: "prod-server-database-01",
SystemSerialNumber: "PC-1234567890",
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
}
pubip := "8.8.8.8"
var resultString string
var resultUint uint64
b.Run("BuilderString", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = builderString(meta, pubip)
}
})
b.Run("FnvHashToString", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = fnvHashToString(meta, pubip)
}
})
b.Run("FnvHashToUint64 - used", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultUint = metaHash(meta, pubip)
}
})
_ = resultString
_ = resultUint
}
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
h := fnv.New64a()
if len(meta.NetworkAddresses) != 0 {
for _, na := range meta.NetworkAddresses {
h.Write([]byte(na.Mac))
}
}
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
return strconv.FormatUint(h.Sum64(), 16)
}
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
mac := getMacAddress(meta.NetworkAddresses)
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
len(pubip) + len(mac) + 6
var b strings.Builder
b.Grow(estimatedSize)
b.WriteString(meta.WtVersion)
b.WriteByte('|')
b.WriteString(meta.OSVersion)
b.WriteByte('|')
b.WriteString(meta.KernelVersion)
b.WriteByte('|')
b.WriteString(meta.Hostname)
b.WriteByte('|')
b.WriteString(meta.SystemSerialNumber)
b.WriteByte('|')
b.WriteString(pubip)
return b.String()
}
func getMacAddress(nas []nbpeer.NetworkAddress) string {
if len(nas) == 0 {
return ""
}
macs := make([]string, 0, len(nas))
for _, na := range nas {
macs = append(macs, na.Mac)
}
return strings.Join(macs, "/")
}
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
filter := newLoginFilterWithCfg(testAdvancedCfg())
numKeys := 100000
pubKeys := make([]string, numKeys)
for i := range numKeys {
pubKeys[i] = "PUB_KEY_" + strconv.Itoa(i)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for pb.Next() {
key := pubKeys[r.Intn(numKeys)]
meta := r.Uint64()
if filter.allowLogin(key, meta) {
filter.addLogin(key, meta)
}
}
})
}

View File

@@ -22,10 +22,14 @@ import (
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -321,99 +325,6 @@ func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServ
return loginResp, nil
}
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
testingServerKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
}
testingClientKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
}
testCases := []struct {
name string
inputFlow *config.DeviceAuthorizationFlow
expectedFlow *mgmtProto.DeviceAuthorizationFlow
expectedErrFunc require.ErrorAssertionFunc
expectedErrMSG string
expectedComparisonFunc require.ComparisonAssertionFunc
expectedComparisonMSG string
}{
{
name: "Testing No Device Flow Config",
inputFlow: nil,
expectedErrFunc: require.Error,
expectedErrMSG: "should return error",
},
{
name: "Testing Invalid Device Flow Provider Config",
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "NoNe",
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
expectedErrFunc: require.Error,
expectedErrMSG: "should return error",
},
{
name: "Testing Full Device Flow Config",
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "hosted",
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
Provider: 0,
ProviderConfig: &mgmtProto.ProviderConfig{
ClientID: "test",
},
},
expectedErrFunc: require.NoError,
expectedErrMSG: "should not return error",
expectedComparisonFunc: require.Equal,
expectedComparisonMSG: "should match",
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &GRPCServer{
wgKey: testingServerKey,
config: &config.Config{
DeviceAuthorizationFlow: testCase.inputFlow,
},
}
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message)
require.NoError(t, err, "should be able to encrypt message")
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
context.TODO(),
&mgmtProto.EncryptedMessage{
WgPubKey: testingClientKey.PublicKey().String(),
Body: encryptedMSG,
},
)
testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
if testCase.expectedComparisonFunc != nil {
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
require.NoError(t, err, "should be able to decrypt")
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
}
})
}
}
func startManagementForTest(t *testing.T, testFile string, config *config.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
@@ -427,7 +338,6 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
t.Fatal(err)
}
peersUpdateManager := NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
ctx := context.WithValue(context.Background(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
@@ -451,7 +361,10 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock()
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := BuildManager(ctx, store, networkMapController, nil, "",
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
@@ -459,10 +372,10 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
return nil, nil, "", cleanup, err
}
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
ephemeralMgr := manager.NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}, networkMapController)
if err != nil {
return nil, nil, "", cleanup, err
}
@@ -764,9 +677,38 @@ func Test_LoginPerformance(t *testing.T) {
peerLogin := types.PeerLogin{
WireGuardPubKey: key.String(),
SSHKey: "random",
Meta: extractPeerMeta(context.Background(), meta),
SetupKey: setupKey.Key,
ConnectionIP: net.IP{1, 1, 1, 1},
Meta: nbpeer.PeerSystemMeta{
Hostname: meta.GetHostname(),
GoOS: meta.GetGoOS(),
Kernel: meta.GetKernel(),
Platform: meta.GetPlatform(),
OS: meta.GetOS(),
OSVersion: meta.GetOSVersion(),
WtVersion: meta.GetNetbirdVersion(),
UIVersion: meta.GetUiVersion(),
KernelVersion: meta.GetKernelVersion(),
SystemSerialNumber: meta.GetSysSerialNumber(),
SystemProductName: meta.GetSysProductName(),
SystemManufacturer: meta.GetSysManufacturer(),
Environment: nbpeer.Environment{
Cloud: meta.GetEnvironment().GetCloud(),
Platform: meta.GetEnvironment().GetPlatform(),
},
Flags: nbpeer.Flags{
RosenpassEnabled: meta.GetFlags().GetRosenpassEnabled(),
RosenpassPermissive: meta.GetFlags().GetRosenpassPermissive(),
ServerSSHAllowed: meta.GetFlags().GetServerSSHAllowed(),
DisableClientRoutes: meta.GetFlags().GetDisableClientRoutes(),
DisableServerRoutes: meta.GetFlags().GetDisableServerRoutes(),
DisableDNS: meta.GetFlags().GetDisableDNS(),
DisableFirewall: meta.GetFlags().GetDisableFirewall(),
BlockLANAccess: meta.GetFlags().GetBlockLANAccess(),
BlockInbound: meta.GetFlags().GetBlockInbound(),
LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(),
},
},
SetupKey: setupKey.Key,
ConnectionIP: net.IP{1, 1, 1, 1},
}
login := func() error {

View File

@@ -20,7 +20,10 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
@@ -176,7 +179,6 @@ func startServer(
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
}
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
@@ -199,13 +201,18 @@ func startServer(
AnyTimes()
permissionsManager := permissions.NewManager(str)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(ctx, str)
networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := server.BuildManager(
context.Background(),
str,
peersUpdateManager,
networkMapController,
nil,
"",
"netbird.selfhosted",
eventStore,
nil,
false,
@@ -220,18 +227,18 @@ func startServer(
}
groupsManager := groups.NewManager(str, permissionsManager, accountManager)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(
context.Background(),
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := nbgrpc.NewServer(
config,
accountManager,
settingsMockManager,
peersUpdateManager,
updateManager,
secretsManager,
nil,
&manager.EphemeralManager{},
nil,
server.MockIntegratedValidator{},
networkMapController,
)
if err != nil {
t.Fatalf("failed creating management server: %v", err)

View File

@@ -38,7 +38,7 @@ type MockAccountManager struct {
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
@@ -94,7 +94,7 @@ type MockAccountManager struct {
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error
@@ -178,11 +178,11 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
}
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
@@ -747,11 +747,11 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLog
}
// SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if am.SyncPeerFunc != nil {
return am.SyncPeerFunc(ctx, sync, accountID)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
}
// GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface

View File

@@ -83,9 +83,6 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -137,9 +134,6 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -183,9 +177,6 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}

View File

@@ -11,6 +11,8 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -785,7 +787,13 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
AnyTimes()
permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
}
func createNSStore(t *testing.T) (store.Store, error) {
@@ -975,7 +983,7 @@ func TestValidateDomain(t *testing.T) {
}
func TestNameServerAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
var newNameServerGroupA *nbdns.NameServerGroup
var newNameServerGroupB *nbdns.NameServerGroup
@@ -994,9 +1002,9 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
})
assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updateManager.CloseChannel(context.Background(), peer1.ID)
})
// Creating a nameserver group with a distribution group no peers should not update account peers

View File

@@ -1,80 +0,0 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
)
func (am *DefaultAccountManager) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
am.enrichAccountFromHolder(account)
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
}
func (am *DefaultAccountManager) getPeerNetworkMapExp(
ctx context.Context,
accountId string,
peerId string,
validatedPeers map[string]struct{},
customZone nbdns.CustomZone,
metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
account := am.getAccountFromHolderOrInit(accountId)
if account == nil {
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
return &types.NetworkMap{
Network: &types.Network{},
}
}
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
}
func (am *DefaultAccountManager) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error {
am.enrichAccountFromHolder(account)
return account.OnPeerAddedUpdNetworkMapCache(peerId)
}
func (am *DefaultAccountManager) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
am.enrichAccountFromHolder(account)
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
}
func (am *DefaultAccountManager) updatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
account := am.getAccountFromHolder(accountId)
if account == nil {
return
}
account.UpdatePeerInNetworkMapCache(peer)
}
func (am *DefaultAccountManager) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
account.RecalculateNetworkMapCache(validatedPeers)
am.updateAccountInHolder(account)
}
func (am *DefaultAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
if am.experimentalNetworkMap(accountId) {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
return err
}
validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
return err
}
am.recalculateNetworkMapCache(account, validatedPeers)
}
return nil
}
func (am *DefaultAccountManager) experimentalNetworkMap(accountId string) bool {
_, ok := am.expNewNetworkMapAIDs[accountId]
return am.expNewNetworkMap || ok
}

View File

@@ -177,9 +177,6 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
event()
}
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil

View File

@@ -157,9 +157,6 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
event()
}
if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil {
return nil, err
}
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
return resource, nil
@@ -260,9 +257,6 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
event()
}
if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil {
return nil, err
}
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
return resource, nil
@@ -337,9 +331,6 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
event()
}
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil

View File

@@ -119,9 +119,6 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network))
if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil {
return nil, err
}
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
return router, nil
@@ -186,9 +183,6 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network))
if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil {
return nil, err
}
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
return router, nil
@@ -223,9 +217,6 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
event()
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil

View File

@@ -8,8 +8,6 @@ import (
"net"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/rs/xid"
@@ -23,7 +21,6 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
@@ -31,7 +28,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -140,12 +136,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
if expired {
if am.experimentalNetworkMap(accountID) {
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
am.BufferUpdateAccountPeers(ctx, accountID)
am.networkMapController.OnPeerUpdated(accountID, peer)
}
return nil
@@ -201,7 +192,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
var peer *nbpeer.Peer
var settings *types.Settings
var peerGroupList []string
var requiresPeerUpdates bool
var peerLabelChanged bool
var sshChanged bool
var loginExpirationChanged bool
@@ -224,9 +214,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return err
}
dnsDomain = am.GetDNSDomain(settings)
dnsDomain = am.networkMapController.GetDNSDomain(settings)
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra)
update, _, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra)
if err != nil {
return err
}
@@ -319,15 +309,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
}
if am.experimentalNetworkMap(accountID) {
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
if peerLabelChanged || requiresPeerUpdates {
am.UpdateAccountPeers(ctx, accountID)
} else if sshChanged {
am.UpdateAccountPeer(ctx, accountID, peer.ID)
}
am.networkMapController.OnPeerUpdated(accountID, peer)
return peer, nil
}
@@ -383,20 +365,13 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
storeEvent()
}
if am.experimentalNetworkMap(accountID) {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
if err := am.onPeerDeletedUpdNetworkMapCache(account, peerID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err)
}
err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err)
}
if userID != activity.SystemInitiator {
am.BufferUpdateAccountPeers(ctx, accountID)
if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peerID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err)
}
return nil
@@ -404,47 +379,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
if err != nil {
return nil, err
}
peer := account.GetPeer(peerID)
if peer == nil {
return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
}
groups := make(map[string][]string)
for groupID, group := range account.Groups {
groups[groupID] = group.Peers
}
validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, err
}
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, err
}
var networkMap *types.NetworkMap
if am.experimentalNetworkMap(peer.AccountID) {
networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
networkMap.Merge(proxyNetworkMap)
}
return networkMap, nil
return am.networkMapController.GetNetworkMap(ctx, peerID)
}
// GetPeerNetwork returns the Network for a given peer
@@ -703,27 +638,19 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
}
opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings))
opEvent.Meta = newPeer.EventMeta(am.networkMapController.GetDNSDomain(settings))
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
}
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
if am.experimentalNetworkMap(accountID) {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if err := am.onPeerAddedUpdNetworkMapCache(account, newPeer.ID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
}
if err := am.networkMapController.OnPeerAdded(ctx, accountID, newPeer.ID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
}
am.BufferUpdateAccountPeers(ctx, accountID)
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer)
return p, nmap, pc, err
}
func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) {
@@ -738,7 +665,7 @@ func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) {
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
var peer *nbpeer.Peer
var peerNotValid bool
var isStatusChanged bool
@@ -748,7 +675,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, 0, err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -798,17 +725,14 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil
})
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, 0, err
}
if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) {
if am.experimentalNetworkMap(accountID) {
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
am.BufferUpdateAccountPeers(ctx, accountID)
am.networkMapController.OnPeerUpdated(accountID, peer)
}
return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
}
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
@@ -933,15 +857,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction))
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
if am.experimentalNetworkMap(accountID) {
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
startBuffer := time.Now()
am.BufferUpdateAccountPeers(ctx, accountID)
log.WithContext(ctx).Debugf("LoginPeer: BufferUpdateAccountPeers took %v", time.Since(startBuffer))
am.networkMapController.OnPeerUpdated(accountID, peer)
}
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
return p, nmap, pc, err
}
// getPeerPostureChecks returns the posture checks for the peer.
@@ -1033,68 +953,6 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil
}
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
}
emptyMap := &types.NetworkMap{
Network: network.Copy(),
}
return peer, emptyMap, nil, nil
}
var (
account *types.Account
err error
)
if am.experimentalNetworkMap(accountID) {
account = am.getAccountFromHolderOrInit(accountID)
} else {
account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
}
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, err
}
startPosture := time.Now()
postureChecks, err := am.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, err
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, nil, nil, err
}
var networkMap *types.NetworkMap
if am.experimentalNetworkMap(accountID) {
networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics())
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
networkMap.Merge(proxyNetworkMap)
}
return peer, networkMap, postureChecks, nil
}
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction store.Store, user *types.User, peer *nbpeer.Peer) error {
err := checkAuth(ctx, user.Id, peer)
if err != nil {
@@ -1118,7 +976,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
return fmt.Errorf("failed to get account settings: %w", err)
}
am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain(settings)))
am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.networkMapController.GetDNSDomain(settings)))
return nil
}
@@ -1214,232 +1072,17 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
// UpdateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
var (
account *types.Account
err error
)
if am.experimentalNetworkMap(accountID) {
account = am.getAccountFromHolderOrInit(accountID)
} else {
account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
return
}
}
globalStart := time.Now()
hasPeersConnected := false
for _, peer := range account.Peers {
if am.peersUpdateManager.HasChannel(peer.ID) {
hasPeersConnected = true
break
}
}
if !hasPeersConnected {
return
}
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err)
return
}
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
dnsCache := &DNSConfigCache{}
dnsDomain := am.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
if am.experimentalNetworkMap(accountID) {
am.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
}
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return
}
extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err)
return
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
for _, peer := range account.Peers {
if !am.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
continue
}
wg.Add(1)
semaphore <- struct{}{}
go func(p *nbpeer.Peer) {
defer wg.Done()
defer func() { <-semaphore }()
start := time.Now()
postureChecks, err := am.getPeerPostureChecks(account, p.ID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err)
return
}
am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
var remotePeerNetworkMap *types.NetworkMap
if am.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics())
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
}
am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
start = time.Now()
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now()
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
}(peer)
}
//
wg.Wait()
if am.metrics != nil {
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
}
}
type bufferUpdate struct {
mu sync.Mutex
next *time.Timer
update atomic.Bool
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID)
}
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
bufUpd, _ := am.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
am.UpdateAccountPeers(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(am.updateAccountPeersBufferInterval.Load()), func() {
am.UpdateAccountPeers(ctx, accountID)
})
return
}
b.next.Reset(time.Duration(am.updateAccountPeersBufferInterval.Load()))
}()
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID)
}
// UpdateAccountPeer updates a single peer that belongs to an account.
// Should be called when changes need to be synced to a specific peer only.
func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) {
if !am.peersUpdateManager.HasChannel(peerId) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId)
return
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peer %s. failed to get account: %v", peerId, err)
return
}
peer := account.GetPeer(peerId)
if peer == nil {
log.WithContext(ctx).Tracef("peer %s doesn't exists in account %s", peerId, accountId)
return
}
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err)
return
}
dnsCache := &DNSConfigCache{}
dnsDomain := am.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
postureChecks, err := am.getPeerPostureChecks(account, peerId)
if err != nil {
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err)
return
}
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId, peerId, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return
}
var remotePeerNetworkMap *types.NetworkMap
if am.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics())
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, peer.AccountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
return
}
peerGroups := account.GetPeerGroups(peerId)
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
_ = am.networkMapController.UpdateAccountPeer(ctx, accountId, peerId)
}
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
@@ -1594,14 +1237,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
if err != nil {
return nil, err
}
dnsDomain := am.GetDNSDomain(settings)
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
dnsFwdPort := computeForwarderPort(peers, dnsForwarderPortMinVersion)
dnsDomain := am.networkMapController.GetDNSDomain(settings)
for _, peer := range peers {
if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
@@ -1635,24 +1271,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil {
return nil, err
}
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{
Update: &proto.SyncResponse{
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
NetworkMap: &proto.NetworkMap{
Serial: network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
DNSConfig: &proto.DNSConfig{
ForwarderPort: dnsFwdPort,
},
},
},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
peerDeletedEvents = append(peerDeletedEvents, func() {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
})
@@ -1661,14 +1279,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
return peerDeletedEvents, nil
}
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
labelMap := make(map[string]struct{}, len(existingLabels))
for _, label := range existingLabels {
labelMap[label] = struct{}{}
}
return labelMap
}
// validatePeerDelete checks if the peer can be deleted.
func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transaction store.Store, accountId, peerId string) error {
linkedInIngressPorts, err := am.proxyController.IsPeerInIngressPorts(ctx, accountId, peerId)

View File

@@ -13,7 +13,6 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@@ -25,10 +24,14 @@ import (
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/shared/management/status"
@@ -172,12 +175,12 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
}
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testGetNetworkMapGeneral(t)
}
func testGetNetworkMapGeneral(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -249,7 +252,7 @@ func testGetNetworkMapGeneral(t *testing.T) {
func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
// TODO: disable until we start use policy again
t.Skip()
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -426,7 +429,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
}
func TestAccountManager_GetPeerNetwork(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -487,7 +490,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
}
func TestDefaultAccountManager_GetPeer(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -674,7 +677,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -742,12 +745,12 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
}
}
func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, string, string, error) {
func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, *update_channel.PeersUpdateManager, string, string, error) {
b.Helper()
manager, err := createManager(b)
manager, updateManager, err := createManager(b)
if err != nil {
return nil, "", "", err
return nil, nil, "", "", err
}
accountID := "test_account"
@@ -798,7 +801,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
ips := account.GetTakenIPs()
peerIP, err := types.AllocatePeerIP(account.Network.Net, ips)
if err != nil {
return nil, "", "", err
return nil, nil, "", "", err
}
peerKey, _ := wgtypes.GeneratePrivateKey()
@@ -904,10 +907,10 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
err = manager.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, "", "", err
return nil, nil, "", "", err
}
return manager, accountID, regularUser, nil
return manager, updateManager, accountID, regularUser, nil
}
func BenchmarkGetPeers(b *testing.B) {
@@ -928,7 +931,7 @@ func BenchmarkGetPeers(b *testing.B) {
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
manager, _, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -968,7 +971,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -980,14 +983,10 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
updateManager.CreateChannel(ctx, peerID)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
@@ -1013,7 +1012,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
}
func TestUpdateAccountPeers_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testUpdateAccountPeers(t)
}
@@ -1037,7 +1036,7 @@ func testUpdateAccountPeers(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
manager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups)
manager, updateManager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups)
if err != nil {
t.Fatalf("Failed to setup test account manager: %v", err)
}
@@ -1049,13 +1048,12 @@ func testUpdateAccountPeers(t *testing.T) {
t.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
peerChannels := make(map[string]chan *network_map.UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
peerChannels[peerID] = updateManager.CreateChannel(ctx, peerID)
}
manager.peersUpdateManager.peerChannels = peerChannels
manager.UpdateAccountPeers(ctx, account.Id)
for _, channel := range peerChannels {
@@ -1097,7 +1095,7 @@ func TestToSyncResponse(t *testing.T) {
DNSLabel: "peer1",
SSHKey: "peer1-ssh-key",
}
turnRelayToken := &Token{
turnRelayToken := &grpc.Token{
Payload: "turn-user",
Signature: "turn-pass",
}
@@ -1177,9 +1175,9 @@ func TestToSyncResponse(t *testing.T) {
},
},
}
dnsCache := &DNSConfigCache{}
dnsCache := &cache.DNSConfigCache{}
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort))
response := grpc.ToSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort))
assert.NotNil(t, response)
// assert peer config
@@ -1289,7 +1287,12 @@ func Test_RegisterPeerByUser(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1369,7 +1372,12 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
AnyTimes()
permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1517,7 +1525,12 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1566,7 +1579,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
}
func Test_LoginPeer(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
@@ -1592,7 +1605,12 @@ func Test_LoginPeer(t *testing.T) {
AnyTimes()
permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1725,7 +1743,7 @@ func Test_LoginPeer(t *testing.T) {
}
func TestPeerAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID)
require.NoError(t, err)
@@ -1782,13 +1800,14 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
var peer5 *nbpeer.Peer
var peer6 *nbpeer.Peer
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updateManager.CloseChannel(context.Background(), peer1.ID)
})
// Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update
t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) {
t.Skip("Currently all updates will trigger a network map")
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
@@ -1890,6 +1909,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
})
t.Run("validator requires no update", func(t *testing.T) {
t.Skip("Currently all updates will trigger a network map")
requireNoUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, false, nil
}
@@ -2091,7 +2112,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
}
func Test_DeletePeer(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -2188,7 +2209,7 @@ func Test_IsUniqueConstraintError(t *testing.T) {
}
func Test_AddPeer(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -2276,136 +2297,8 @@ func Test_AddPeer(t *testing.T) {
assert.Equal(t, uint64(totalPeers), account.Network.Serial)
}
func TestBufferUpdateAccountPeers(t *testing.T) {
const (
peersCount = 1000
updateAccountInterval = 50 * time.Millisecond
)
var (
deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32
uapLastRun, dpLastRun atomic.Int64
totalNewRuns, totalOldRuns int
)
uap := func(ctx context.Context, accountID string) {
updatePeersDeleted.Store(deletedPeers.Load())
updatePeersRuns.Add(1)
uapLastRun.Store(time.Now().UnixMilli())
time.Sleep(100 * time.Millisecond)
}
t.Run("new approach", func(t *testing.T) {
updatePeersRuns.Store(0)
updatePeersDeleted.Store(0)
deletedPeers.Store(0)
var mustore sync.Map
bufupd := func(ctx context.Context, accountID string) {
mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{})
b := mu.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
uap(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
b.next = time.AfterFunc(updateAccountInterval, func() {
uap(ctx, accountID)
})
}()
}
dp := func(ctx context.Context, accountID, peerID, userID string) error {
deletedPeers.Add(1)
dpLastRun.Store(time.Now().UnixMilli())
time.Sleep(10 * time.Millisecond)
bufupd(ctx, accountID)
return nil
}
am := mock_server.MockAccountManager{
UpdateAccountPeersFunc: uap,
BufferUpdateAccountPeersFunc: bufupd,
DeletePeerFunc: dp,
}
empty := ""
for range peersCount {
//nolint
am.DeletePeer(context.Background(), empty, empty, empty)
}
time.Sleep(100 * time.Millisecond)
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
totalNewRuns = int(updatePeersRuns.Load())
})
t.Run("old approach", func(t *testing.T) {
updatePeersRuns.Store(0)
updatePeersDeleted.Store(0)
deletedPeers.Store(0)
var mustore sync.Map
bufupd := func(ctx context.Context, accountID string) {
mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{})
b := mu.(*sync.Mutex)
if !b.TryLock() {
return
}
go func() {
time.Sleep(updateAccountInterval)
b.Unlock()
uap(ctx, accountID)
}()
}
dp := func(ctx context.Context, accountID, peerID, userID string) error {
deletedPeers.Add(1)
dpLastRun.Store(time.Now().UnixMilli())
time.Sleep(10 * time.Millisecond)
bufupd(ctx, accountID)
return nil
}
am := mock_server.MockAccountManager{
UpdateAccountPeersFunc: uap,
BufferUpdateAccountPeersFunc: bufupd,
DeletePeerFunc: dp,
}
empty := ""
for range peersCount {
//nolint
am.DeletePeer(context.Background(), empty, empty, empty)
}
time.Sleep(100 * time.Millisecond)
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
totalOldRuns = int(updatePeersRuns.Load())
})
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
}
func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -2442,7 +2335,7 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
}
func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -2476,7 +2369,7 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
}
func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -2541,7 +2434,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
}
func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/posture"
@@ -77,9 +76,6 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -123,9 +119,6 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -258,31 +251,3 @@ func getValidGroupIDs(groups map[string]*types.Group, groupIDs []string) []strin
return validIDs
}
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules))
for i := range rules {
rule := rules[i]
fwRule := &proto.FirewallRule{
PolicyID: []byte(rule.PolicyID),
PeerIP: rule.PeerIP,
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
}
if shouldUsePortRange(fwRule) {
fwRule.PortInfo = rule.PortRange.ToProto()
}
result[i] = fwRule
}
return result
}
func shouldUsePortRange(rule *proto.FirewallRule) bool {
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
}

View File

@@ -1135,7 +1135,7 @@ func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int {
}
func TestPolicyAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
g := []*types.Group{
{
@@ -1164,9 +1164,9 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
assert.NoError(t, err)
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updateManager.CloseChannel(context.Background(), peer1.ID)
})
var policyWithGroupRulesNoPeers *types.Policy

View File

@@ -2,19 +2,15 @@ package server
import (
"context"
"errors"
"fmt"
"slices"
"github.com/rs/xid"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -80,9 +76,6 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -139,27 +132,6 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
}
// getPeerPostureChecks returns the posture checks applied for a given peer.
func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) {
peerPostureChecks := make(map[string]*posture.Checks)
if len(account.PostureChecks) == 0 {
return nil, nil
}
for _, policy := range account.Policies {
if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
continue
}
if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
return nil, err
}
}
return maps.Values(peerPostureChecks), nil
}
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
@@ -214,50 +186,6 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account
return nil
}
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error {
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
if err != nil {
return err
}
if !isInGroup {
return nil
}
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
postureCheck := account.GetPostureChecks(sourcePostureCheckID)
if postureCheck == nil {
return errors.New("failed to add policy posture checks: posture checks not found")
}
peerPostureChecks[sourcePostureCheckID] = postureCheck
}
return nil
}
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
for _, sourceGroup := range rule.Sources {
group := account.GetGroup(sourceGroup)
if group == nil {
return false, fmt.Errorf("failed to check peer in policy source group: group not found")
}
if slices.Contains(group.Peers, peerID) {
return true, nil
}
}
}
return false, nil
}
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)

View File

@@ -21,7 +21,7 @@ const (
)
func TestDefaultAccountManager_PostureCheck(t *testing.T) {
am, err := createManager(t)
am, _, err := createManager(t)
if err != nil {
t.Error("failed to create account manager")
}
@@ -123,7 +123,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
}
func TestPostureCheckAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
g := []*types.Group{
{
@@ -147,9 +147,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
assert.NoError(t, err)
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updateManager.CloseChannel(context.Background(), peer1.ID)
})
postureCheckA := &posture.Checks{
@@ -359,9 +359,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// Updating linked posture check to policy where destination has peers but source does not
// should trigger account peers update and send peer update
t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) {
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
updMsg1 := updateManager.CreateChannel(context.Background(), peer2.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
updateManager.CloseChannel(context.Background(), peer2.ID)
})
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
@@ -445,7 +445,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}
func TestArePostureCheckChangesAffectPeers(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
require.NoError(t, err, "failed to create account manager")
account, err := initTestPostureChecksAccount(manager)

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -192,9 +191,6 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -249,9 +245,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -295,9 +288,6 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -381,103 +371,12 @@ func validateRouteGroups(ctx context.Context, transaction store.Store, accountID
return groupsMap, nil
}
func toProtocolRoute(route *route.Route) *proto.Route {
return &proto.Route{
ID: string(route.ID),
NetID: string(route.NetID),
Network: route.Network.String(),
Domains: route.Domains.ToPunycodeList(),
NetworkType: int64(route.NetworkType),
Peer: route.Peer,
Metric: int64(route.Metric),
Masquerade: route.Masquerade,
KeepRoute: route.KeepRoute,
SkipAutoApply: route.SkipAutoApply,
}
}
func toProtocolRoutes(routes []*route.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0, len(routes))
for _, r := range routes {
protoRoutes = append(protoRoutes, toProtocolRoute(r))
}
return protoRoutes
}
// getPlaceholderIP returns a placeholder IP address for the route if domains are used
func getPlaceholderIP() netip.Prefix {
// Using an IP from the documentation range to minimize impact in case older clients try to set a route
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
}
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
result := make([]*proto.RouteFirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.RouteFirewallRule{
SourceRanges: rule.SourceRanges,
Action: getProtoAction(rule.Action),
Destination: rule.Destination,
Protocol: getProtoProtocol(rule.Protocol),
PortInfo: getProtoPortInfo(rule),
IsDynamic: rule.IsDynamic,
Domains: rule.Domains.ToPunycodeList(),
PolicyID: []byte(rule.PolicyID),
RouteID: string(rule.RouteID),
}
}
return result
}
// getProtoDirection converts the direction to proto.RuleDirection.
func getProtoDirection(direction int) proto.RuleDirection {
if direction == types.FirewallRuleDirectionOUT {
return proto.RuleDirection_OUT
}
return proto.RuleDirection_IN
}
// getProtoAction converts the action to proto.RuleAction.
func getProtoAction(action string) proto.RuleAction {
if action == string(types.PolicyTrafficActionDrop) {
return proto.RuleAction_DROP
}
return proto.RuleAction_ACCEPT
}
// getProtoProtocol converts the protocol to proto.RuleProtocol.
func getProtoProtocol(protocol string) proto.RuleProtocol {
switch types.PolicyRuleProtocolType(protocol) {
case types.PolicyRuleProtocolALL:
return proto.RuleProtocol_ALL
case types.PolicyRuleProtocolTCP:
return proto.RuleProtocol_TCP
case types.PolicyRuleProtocolUDP:
return proto.RuleProtocol_UDP
case types.PolicyRuleProtocolICMP:
return proto.RuleProtocol_ICMP
default:
return proto.RuleProtocol_UNKNOWN
}
}
// getProtoPortInfo converts the port info to proto.PortInfo.
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
var portInfo proto.PortInfo
if rule.Port != 0 {
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
portInfo.PortSelection = &proto.PortInfo_Range_{
Range: &proto.PortInfo_Range{
Start: uint32(portRange.Start),
End: uint32(portRange.End),
},
}
}
return &portInfo
}
// areRouteChangesAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers.
func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {

View File

@@ -14,6 +14,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -432,7 +434,7 @@ func TestCreateRoute(t *testing.T) {
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
am, err := createRouterManager(t)
am, _, err := createRouterManager(t)
if err != nil {
t.Error("failed to create account manager")
}
@@ -922,7 +924,7 @@ func TestSaveRoute(t *testing.T) {
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
am, err := createRouterManager(t)
am, _, err := createRouterManager(t)
if err != nil {
t.Error("failed to create account manager")
}
@@ -1024,7 +1026,7 @@ func TestDeleteRoute(t *testing.T) {
Enabled: true,
}
am, err := createRouterManager(t)
am, _, err := createRouterManager(t)
if err != nil {
t.Error("failed to create account manager")
}
@@ -1071,7 +1073,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
AccessControlGroups: []string{routeGroup1},
}
am, err := createRouterManager(t)
am, _, err := createRouterManager(t)
if err != nil {
t.Error("failed to create account manager")
}
@@ -1163,7 +1165,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
AccessControlGroups: []string{routeGroup1},
}
am, err := createRouterManager(t)
am, _, err := createRouterManager(t)
if err != nil {
t.Error("failed to create account manager")
}
@@ -1250,11 +1252,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1")
}
func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) {
t.Helper()
store, err := createRouterStore(t)
if err != nil {
return nil, err
return nil, nil, err
}
eventStore := &activity.InMemoryEventStore{}
@@ -1285,7 +1287,16 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
am, err := BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, nil, err
}
return am, updateManager, nil
}
func createRouterStore(t *testing.T) (store.Store, error) {
@@ -1948,7 +1959,7 @@ func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFi
}
func TestRouteAccountPeersUpdate(t *testing.T) {
manager, err := createRouterManager(t)
manager, updateManager, err := createRouterManager(t)
require.NoError(t, err, "failed to create account manager")
account, err := initTestRouteAccount(t, manager)
@@ -1976,9 +1987,9 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
require.NoError(t, err, "failed to create group %s", group.Name)
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID)
updateManager.CloseChannel(context.Background(), peer1ID)
})
// Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update

View File

@@ -18,7 +18,7 @@ import (
)
func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -93,7 +93,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
}
func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -198,7 +198,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
}
func TestGetSetupKeys(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -396,7 +396,7 @@ func TestSetupKey_Copy(t *testing.T) {
}
func TestSetupKeyAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
@@ -420,9 +420,9 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updateManager.CloseChannel(context.Background(), peer1.ID)
})
var setupKey *types.SetupKey
@@ -465,7 +465,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
}
func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}

View File

@@ -1,270 +0,0 @@
package server
import (
"context"
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/shared/management/proto"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
)
const defaultDuration = 12 * time.Hour
// SecretsManager used to manage TURN and relay secrets
type SecretsManager interface {
GenerateTurnToken() (*Token, error)
GenerateRelayToken() (*Token, error)
SetupRefresh(ctx context.Context, accountID, peerKey string)
CancelRefresh(peerKey string)
}
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
type TimeBasedAuthSecretsManager struct {
mux sync.Mutex
turnCfg *nbconfig.TURNConfig
relayCfg *nbconfig.Relay
turnHmacToken *auth.TimedHMAC
relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager
settingsManager settings.Manager
groupsManager groups.Manager
turnCancelMap map[string]chan struct{}
relayCancelMap map[string]chan struct{}
}
type Token auth.Token
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
mgr := &TimeBasedAuthSecretsManager{
updateManager: updateManager,
turnCfg: turnCfg,
relayCfg: relayCfg,
turnCancelMap: make(map[string]chan struct{}),
relayCancelMap: make(map[string]chan struct{}),
settingsManager: settingsManager,
groupsManager: groupsManager,
}
if turnCfg != nil {
duration := turnCfg.CredentialsTTL.Duration
if turnCfg.CredentialsTTL.Duration <= 0 {
log.Warnf("TURN credentials TTL is not set or invalid, using default value %s", defaultDuration)
duration = defaultDuration
}
mgr.turnHmacToken = auth.NewTimedHMAC(turnCfg.Secret, duration)
}
if relayCfg != nil {
duration := relayCfg.CredentialsTTL.Duration
if relayCfg.CredentialsTTL.Duration <= 0 {
log.Warnf("Relay credentials TTL is not set or invalid, using default value %s", defaultDuration)
duration = defaultDuration
}
hashedSecret := sha256.Sum256([]byte(relayCfg.Secret))
var err error
if mgr.relayHmacToken, err = authv2.NewGenerator(authv2.AuthAlgoHMACSHA256, hashedSecret[:], duration); err != nil {
log.Errorf("failed to create relay token generator: %s", err)
}
}
return mgr
}
// GenerateTurnToken generates new time-based secret credentials for TURN
func (m *TimeBasedAuthSecretsManager) GenerateTurnToken() (*Token, error) {
if m.turnHmacToken == nil {
return nil, fmt.Errorf("TURN configuration is not set")
}
turnToken, err := m.turnHmacToken.GenerateToken(sha1.New)
if err != nil {
return nil, fmt.Errorf("generate TURN token: %s", err)
}
return (*Token)(turnToken), nil
}
// GenerateRelayToken generates new time-based secret credentials for relay
func (m *TimeBasedAuthSecretsManager) GenerateRelayToken() (*Token, error) {
if m.relayHmacToken == nil {
return nil, fmt.Errorf("relay configuration is not set")
}
relayToken, err := m.relayHmacToken.GenerateToken()
if err != nil {
return nil, fmt.Errorf("generate relay token: %s", err)
}
return &Token{
Payload: string(relayToken.Payload),
Signature: base64.StdEncoding.EncodeToString(relayToken.Signature),
}, nil
}
func (m *TimeBasedAuthSecretsManager) cancelTURN(peerID string) {
if channel, ok := m.turnCancelMap[peerID]; ok {
close(channel)
delete(m.turnCancelMap, peerID)
}
}
func (m *TimeBasedAuthSecretsManager) cancelRelay(peerID string) {
if channel, ok := m.relayCancelMap[peerID]; ok {
close(channel)
delete(m.relayCancelMap, peerID)
}
}
// CancelRefresh cancels scheduled peer credentials refresh
func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
m.mux.Lock()
defer m.mux.Unlock()
m.cancelTURN(peerID)
m.cancelRelay(peerID)
}
// SetupRefresh starts peer credentials refresh
func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, accountID, peerID string) {
m.mux.Lock()
defer m.mux.Unlock()
m.cancelTURN(peerID)
m.cancelRelay(peerID)
if m.turnCfg != nil && m.turnCfg.TimeBasedCredentials {
turnCancel := make(chan struct{}, 1)
m.turnCancelMap[peerID] = turnCancel
go m.refreshTURNTokens(ctx, accountID, peerID, turnCancel)
log.WithContext(ctx).Debugf("starting TURN refresh for %s", peerID)
}
if m.relayCfg != nil {
relayCancel := make(chan struct{}, 1)
m.relayCancelMap[peerID] = relayCancel
go m.refreshRelayTokens(ctx, accountID, peerID, relayCancel)
log.WithContext(ctx).Debugf("starting relay refresh for %s", peerID)
}
}
func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, accountID, peerID string, cancel chan struct{}) {
ticker := time.NewTicker(m.turnCfg.CredentialsTTL.Duration / 4 * 3)
defer ticker.Stop()
for {
select {
case <-cancel:
log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID)
return
case <-ticker.C:
m.pushNewTURNAndRelayTokens(ctx, accountID, peerID)
}
}
}
func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, accountID, peerID string, cancel chan struct{}) {
ticker := time.NewTicker(m.relayCfg.CredentialsTTL.Duration / 4 * 3)
defer ticker.Stop()
for {
select {
case <-cancel:
log.WithContext(ctx).Debugf("stopping relay refresh for %s", peerID)
return
case <-ticker.C:
m.pushNewRelayTokens(ctx, accountID, peerID)
}
}
}
func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Context, accountID, peerID string) {
turnToken, err := m.turnHmacToken.GenerateToken(sha1.New)
if err != nil {
log.WithContext(ctx).Errorf("failed to generate token for peer '%s': %s", peerID, err)
return
}
var turns []*proto.ProtectedHostConfig
for _, host := range m.turnCfg.Turns {
turn := &proto.ProtectedHostConfig{
HostConfig: &proto.HostConfig{
Uri: host.URI,
Protocol: ToResponseProto(host.Proto),
},
User: turnToken.Payload,
Password: turnToken.Signature,
}
turns = append(turns, turn)
}
update := &proto.SyncResponse{
NetbirdConfig: &proto.NetbirdConfig{
Turns: turns,
},
}
// workaround for the case when client is unable to handle turn and relay updates at different time
if m.relayCfg != nil {
token, err := m.GenerateRelayToken()
if err == nil {
update.NetbirdConfig.Relay = &proto.RelayConfig{
Urls: m.relayCfg.Addresses,
TokenPayload: token.Payload,
TokenSignature: token.Signature,
}
}
}
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
}
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
relayToken, err := m.relayHmacToken.GenerateToken()
if err != nil {
log.Errorf("failed to generate relay token for peer '%s': %s", peerID, err)
return
}
update := &proto.SyncResponse{
NetbirdConfig: &proto.NetbirdConfig{
Relay: &proto.RelayConfig{
Urls: m.relayCfg.Addresses,
TokenPayload: string(relayToken.Payload),
TokenSignature: base64.StdEncoding.EncodeToString(relayToken.Signature),
},
// omit Turns to avoid updates there
},
}
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
}
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
extraSettings, err := m.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
}
peerGroups, err := m.groupsManager.GetPeerGroupIDs(ctx, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peer groups: %v", err)
}
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, peerGroups, update.NetbirdConfig, extraSettings)
update.NetbirdConfig = extendedConfig
}

View File

@@ -1,245 +0,0 @@
package server
import (
"context"
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"hash"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
var TurnTestHost = &config.Host{
Proto: config.UDP,
URI: "turn:turn.netbird.io:77777",
Username: "username",
Password: "",
}
func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
ttl := util.Duration{Duration: time.Hour}
secret := "some_secret"
peersManager := NewPeersUpdateManager(nil)
rc := &config.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
turnCredentials, err := tested.GenerateTurnToken()
require.NoError(t, err)
if turnCredentials.Payload == "" {
t.Errorf("expected generated TURN username not to be empty, got empty")
}
if turnCredentials.Signature == "" {
t.Errorf("expected generated TURN password not to be empty, got empty")
}
validateMAC(t, sha1.New, turnCredentials.Payload, turnCredentials.Signature, []byte(secret))
relayCredentials, err := tested.GenerateRelayToken()
require.NoError(t, err)
if relayCredentials.Payload == "" {
t.Errorf("expected generated relay payload not to be empty, got empty")
}
if relayCredentials.Signature == "" {
t.Errorf("expected generated relay signature not to be empty, got empty")
}
hashedSecret := sha256.Sum256([]byte(secret))
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, hashedSecret[:])
}
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
ttl := util.Duration{Duration: 2 * time.Second}
secret := "some_secret"
peersManager := NewPeersUpdateManager(nil)
peer := "some_peer"
updateChannel := peersManager.CreateChannel(context.Background(), peer)
rc := &config.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tested.SetupRefresh(ctx, "someAccountID", peer)
if _, ok := tested.turnCancelMap[peer]; !ok {
t.Errorf("expecting peer to be present in the turn cancel map, got not present")
}
if _, ok := tested.relayCancelMap[peer]; !ok {
t.Errorf("expecting peer to be present in the relay cancel map, got not present")
}
var updates []*UpdateMessage
loop:
for timeout := time.After(5 * time.Second); ; {
select {
case update := <-updateChannel:
updates = append(updates, update)
case <-timeout:
break loop
}
if len(updates) >= 2 {
break loop
}
}
if len(updates) < 2 {
t.Errorf("expecting at least 2 peer credentials updates, got %v", len(updates))
}
var turnUpdates, relayUpdates int
var firstTurnUpdate, secondTurnUpdate *proto.ProtectedHostConfig
var firstRelayUpdate, secondRelayUpdate *proto.RelayConfig
for _, update := range updates {
if turns := update.Update.GetNetbirdConfig().GetTurns(); len(turns) > 0 {
turnUpdates++
if turnUpdates == 1 {
firstTurnUpdate = turns[0]
} else {
secondTurnUpdate = turns[0]
}
}
if relay := update.Update.GetNetbirdConfig().GetRelay(); relay != nil {
// avoid updating on turn updates since they also send relay credentials
if update.Update.GetNetbirdConfig().GetTurns() == nil {
relayUpdates++
if relayUpdates == 1 {
firstRelayUpdate = relay
} else {
secondRelayUpdate = relay
}
}
}
}
if turnUpdates < 1 {
t.Errorf("expecting at least 1 TURN credential update, got %v", turnUpdates)
}
if relayUpdates < 1 {
t.Errorf("expecting at least 1 relay credential update, got %v", relayUpdates)
}
if firstTurnUpdate != nil && secondTurnUpdate != nil {
if firstTurnUpdate.Password == secondTurnUpdate.Password {
t.Errorf("expecting first TURN credential update password %v to be different from second, got equal", firstTurnUpdate.Password)
}
}
if firstRelayUpdate != nil && secondRelayUpdate != nil {
if firstRelayUpdate.TokenSignature == secondRelayUpdate.TokenSignature {
t.Errorf("expecting first relay credential update signature %v to be different from second, got equal", firstRelayUpdate.TokenSignature)
}
}
}
func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
ttl := util.Duration{Duration: time.Hour}
secret := "some_secret"
peersManager := NewPeersUpdateManager(nil)
peer := "some_peer"
rc := &config.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
tested.SetupRefresh(context.Background(), "someAccountID", peer)
if _, ok := tested.turnCancelMap[peer]; !ok {
t.Errorf("expecting peer to be present in turn cancel map, got not present")
}
if _, ok := tested.relayCancelMap[peer]; !ok {
t.Errorf("expecting peer to be present in relay cancel map, got not present")
}
tested.CancelRefresh(peer)
if _, ok := tested.turnCancelMap[peer]; ok {
t.Errorf("expecting peer to be not present in turn cancel map, got present")
}
if _, ok := tested.relayCancelMap[peer]; ok {
t.Errorf("expecting peer to be not present in relay cancel map, got present")
}
}
func validateMAC(t *testing.T, algo func() hash.Hash, username string, actualMAC string, key []byte) {
t.Helper()
mac := hmac.New(algo, key)
_, err := mac.Write([]byte(username))
if err != nil {
t.Fatal(err)
}
expectedMAC := mac.Sum(nil)
decodedMAC, err := base64.StdEncoding.DecodeString(actualMAC)
if err != nil {
t.Fatal(err)
}
equal := hmac.Equal(decodedMAC, expectedMAC)
if !equal {
t.Errorf("expected password MAC to be %s. got %s", expectedMAC, decodedMAC)
}
}

View File

@@ -1,176 +0,0 @@
package server
import (
"context"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/shared/management/proto"
)
const channelBufferSize = 100
type UpdateMessage struct {
Update *proto.SyncResponse
}
type PeersUpdateManager struct {
// peerChannels is an update channel indexed by Peer.ID
peerChannels map[string]chan *UpdateMessage
// channelsMux keeps the mutex to access peerChannels
channelsMux *sync.RWMutex
// metrics provides method to collect application metrics
metrics telemetry.AppMetrics
}
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
return &PeersUpdateManager{
peerChannels: make(map[string]chan *UpdateMessage),
channelsMux: &sync.RWMutex{},
metrics: metrics,
}
}
// SendUpdate sends update message to the peer's channel
func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) {
start := time.Now()
var found, dropped bool
p.channelsMux.RLock()
defer func() {
p.channelsMux.RUnlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountSendUpdateDuration(time.Since(start), found, dropped)
}
}()
if channel, ok := p.peerChannels[peerID]; ok {
found = true
select {
case channel <- update:
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
default:
dropped = true
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
}
} else {
log.WithContext(ctx).Debugf("peer %s has no channel", peerID)
}
}
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage {
start := time.Now()
closed := false
p.channelsMux.Lock()
defer func() {
p.channelsMux.Unlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountCreateChannelDuration(time.Since(start), closed)
}
}()
if channel, ok := p.peerChannels[peerID]; ok {
closed = true
delete(p.peerChannels, peerID)
close(channel)
}
// mbragin: todo shouldn't it be more? or configurable?
channel := make(chan *UpdateMessage, channelBufferSize)
p.peerChannels[peerID] = channel
log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID)
return channel
}
func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
if channel, ok := p.peerChannels[peerID]; ok {
delete(p.peerChannels, peerID)
close(channel)
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
return
}
log.WithContext(ctx).Debugf("closing updates channel: peer %s has no channel", peerID)
}
// CloseChannels closes updates channel for each given peer
func (p *PeersUpdateManager) CloseChannels(ctx context.Context, peerIDs []string) {
start := time.Now()
p.channelsMux.Lock()
defer func() {
p.channelsMux.Unlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountCloseChannelsDuration(time.Since(start), len(peerIDs))
}
}()
for _, id := range peerIDs {
p.closeChannel(ctx, id)
}
}
// CloseChannel closes updates channel of a given peer
func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) {
start := time.Now()
p.channelsMux.Lock()
defer func() {
p.channelsMux.Unlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountCloseChannelDuration(time.Since(start))
}
}()
p.closeChannel(ctx, peerID)
}
// GetAllConnectedPeers returns a copy of the connected peers map
func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} {
start := time.Now()
p.channelsMux.RLock()
m := make(map[string]struct{})
defer func() {
p.channelsMux.RUnlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountGetAllConnectedPeersDuration(time.Since(start), len(m))
}
}()
for ID := range p.peerChannels {
m[ID] = struct{}{}
}
return m
}
// HasChannel returns true if peers has channel in update manager, otherwise false
func (p *PeersUpdateManager) HasChannel(peerID string) bool {
start := time.Now()
p.channelsMux.RLock()
defer func() {
p.channelsMux.RUnlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountHasChannelDuration(time.Since(start))
}
}()
_, ok := p.peerChannels[peerID]
return ok
}

View File

@@ -1,79 +0,0 @@
package server
import (
"context"
"testing"
"time"
"github.com/netbirdio/netbird/shared/management/proto"
)
// var peersUpdater *PeersUpdateManager
func TestCreateChannel(t *testing.T) {
peer := "test-create"
peersUpdater := NewPeersUpdateManager(nil)
defer peersUpdater.CloseChannel(context.Background(), peer)
_ = peersUpdater.CreateChannel(context.Background(), peer)
if _, ok := peersUpdater.peerChannels[peer]; !ok {
t.Error("Error creating the channel")
}
}
func TestSendUpdate(t *testing.T) {
peer := "test-sendupdate"
peersUpdater := NewPeersUpdateManager(nil)
update1 := &UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 0,
},
}}
_ = peersUpdater.CreateChannel(context.Background(), peer)
if _, ok := peersUpdater.peerChannels[peer]; !ok {
t.Error("Error creating the channel")
}
peersUpdater.SendUpdate(context.Background(), peer, update1)
select {
case <-peersUpdater.peerChannels[peer]:
default:
t.Error("Update wasn't send")
}
for range [channelBufferSize]int{} {
peersUpdater.SendUpdate(context.Background(), peer, update1)
}
update2 := &UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 10,
},
}}
peersUpdater.SendUpdate(context.Background(), peer, update2)
timeout := time.After(5 * time.Second)
for range [channelBufferSize]int{} {
select {
case <-timeout:
t.Error("timed out reading previously sent updates")
case updateReader := <-peersUpdater.peerChannels[peer]:
if updateReader.Update.NetworkMap.Serial == update2.Update.NetworkMap.Serial {
t.Error("got the update that shouldn't have been sent")
}
}
}
}
func TestCloseChannel(t *testing.T) {
peer := "test-close"
peersUpdater := NewPeersUpdateManager(nil)
_ = peersUpdater.CreateChannel(context.Background(), peer)
if _, ok := peersUpdater.peerChannels[peer]; !ok {
t.Error("Error creating the channel")
}
peersUpdater.CloseChannel(context.Background(), peer)
if _, ok := peersUpdater.peerChannels[peer]; ok {
t.Error("Error closing the channel")
}
}

View File

@@ -965,7 +965,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
if err != nil {
return err
}
dnsDomain := am.GetDNSDomain(settings)
dnsDomain := am.networkMapController.GetDNSDomain(settings)
var peerIDs []string
for _, peer := range peers {
@@ -992,16 +992,13 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
activity.PeerLoginExpired, peer.EventMeta(dnsDomain),
)
if am.experimentalNetworkMap(accountID) {
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
am.networkMapController.OnPeerUpdated(accountID, peer)
}
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID)
am.peersUpdateManager.CloseChannels(ctx, peerIDs)
am.BufferUpdateAccountPeers(ctx, accountID)
am.networkMapController.DisconnectPeers(ctx, peerIDs)
}
return nil
}
@@ -1115,6 +1112,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
var addPeerRemovedEvents []func()
var updateAccountPeers bool
var userPeers []*nbpeer.Peer
var targetUser *types.User
var err error
@@ -1124,7 +1122,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
return fmt.Errorf("failed to get user to delete: %w", err)
}
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID)
userPeers, err = transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID)
if err != nil {
return fmt.Errorf("failed to get user peers: %w", err)
}
@@ -1147,6 +1145,17 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
return false, err
}
for _, peer := range userPeers {
err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err)
}
if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peer.ID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peer.ID, err)
}
}
for _, addPeerRemovedEvent := range addPeerRemovedEvents {
addPeerRemovedEvent()
}

View File

@@ -1161,7 +1161,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
}
func TestDefaultAccountManager_SaveUser(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
return
@@ -1333,7 +1333,7 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
func TestUserAccountPeersUpdate(t *testing.T) {
// account groups propagation is enabled
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA",
@@ -1357,9 +1357,9 @@ func TestUserAccountPeersUpdate(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
updateManager.CloseChannel(context.Background(), peer1.ID)
})
// Creating a new regular user should not update account peers and not send peer update
@@ -1468,9 +1468,9 @@ func TestUserAccountPeersUpdate(t *testing.T) {
}
})
peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID)
peer4UpdMsg := updateManager.CreateChannel(context.Background(), peer4.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID)
updateManager.CloseChannel(context.Background(), peer4.ID)
})
// deleting user with linked peers should update account peers and send peer update
@@ -1748,7 +1748,7 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
}
func TestApproveUser(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}
@@ -1807,7 +1807,7 @@ func TestApproveUser(t *testing.T) {
}
func TestRejectUser(t *testing.T) {
manager, err := createManager(t)
manager, _, err := createManager(t)
if err != nil {
t.Fatal(err)
}