[management] move network map logic into new design (#4774)

This commit is contained in:
Pascal Fischer
2025-11-13 12:09:46 +01:00
committed by GitHub
parent c28275611b
commit cc97cffff1
62 changed files with 2568 additions and 1989 deletions

View File

@@ -12,6 +12,9 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
clientProto "github.com/netbirdio/netbird/client/proto" clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server" client "github.com/netbirdio/netbird/client/server"
@@ -84,7 +87,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil { if err != nil {
return nil, nil return nil, nil
@@ -110,13 +112,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
Return(&types.Settings{}, nil). Return(&types.Settings{}, nil).
AnyTimes() AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}) mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -26,6 +26,9 @@ import (
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
@@ -1556,7 +1559,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil { if err != nil {
return nil, "", err return nil, "", err
@@ -1584,13 +1586,16 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -14,6 +14,9 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
@@ -290,7 +293,6 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil { if err != nil {
return nil, "", err return nil, "", err
@@ -311,13 +313,16 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

2
go.mod
View File

@@ -99,6 +99,7 @@ require (
go.opentelemetry.io/otel/exporters/prometheus v0.48.0 go.opentelemetry.io/otel/exporters/prometheus v0.48.0
go.opentelemetry.io/otel/metric v1.35.0 go.opentelemetry.io/otel/metric v1.35.0
go.opentelemetry.io/otel/sdk/metric v1.35.0 go.opentelemetry.io/otel/sdk/metric v1.35.0
go.uber.org/mock v0.5.0
go.uber.org/zap v1.27.0 go.uber.org/zap v1.27.0
goauthentik.io/api/v3 v3.2023051.3 goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
@@ -242,7 +243,6 @@ require (
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/sdk v1.35.0 // indirect
go.opentelemetry.io/otel/trace v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect
go.uber.org/mock v0.5.0 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect golang.org/x/image v0.18.0 // indirect
golang.org/x/text v0.27.0 // indirect golang.org/x/text v0.27.0 // indirect

View File

@@ -0,0 +1,31 @@
package cache
import (
"sync"
"github.com/netbirdio/netbird/shared/management/proto"
)
// DNSConfigCache is a thread-safe cache for DNS configuration components
type DNSConfigCache struct {
NameServerGroups sync.Map
}
// GetNameServerGroup retrieves a cached name server group
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
if c == nil {
return nil, false
}
if value, ok := c.NameServerGroups.Load(key); ok {
return value.(*proto.NameServerGroup), true
}
return nil, false
}
// SetNameServerGroup stores a name server group in the cache
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
if c == nil {
return
}
c.NameServerGroups.Store(key, value)
}

View File

@@ -0,0 +1,784 @@
package controller
import (
"context"
"errors"
"fmt"
"os"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"golang.org/x/mod/semver"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/util"
)
type Controller struct {
repo Repository
metrics *metrics
// This should not be here, but we need to maintain it for the time being
accountManagerMetrics *telemetry.AccountManagerMetrics
peersUpdateManager network_map.PeersUpdateManager
settingsManager settings.Manager
accountUpdateLocks sync.Map
sendAccountUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
requestBuffer account.RequestBuffer
proxyController port_forwarding.Controller
integratedPeerValidator integrated_validator.IntegratedValidator
holder *types.Holder
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
}
type bufferUpdate struct {
mu sync.Mutex
next *time.Timer
update atomic.Bool
}
var _ network_map.Controller = (*Controller)(nil)
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller) *Controller {
nMetrics, err := newMetrics(metrics.UpdateChannelMetrics())
if err != nil {
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
}
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
newNetworkMapBuilder = false
}
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
expIDs[id] = struct{}{}
}
return &Controller{
repo: newRepository(store),
metrics: nMetrics,
accountManagerMetrics: metrics.AccountManagerMetrics(),
peersUpdateManager: peersUpdateManager,
requestBuffer: requestBuffer,
integratedPeerValidator: integratedPeerValidator,
settingsManager: settingsManager,
dnsDomain: dnsDomain,
proxyController: proxyController,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
}
}
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
var (
account *types.Account
err error
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
}
globalStart := time.Now()
hasPeersConnected := false
for _, peer := range account.Peers {
if c.peersUpdateManager.HasChannel(peer.ID) {
hasPeersConnected = true
break
}
}
if !hasPeersConnected {
return nil
}
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return fmt.Errorf("failed to get validate peers: %v", err)
}
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
if c.experimentalNetworkMap(accountID) {
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
}
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return fmt.Errorf("failed to get proxy network maps: %v", err)
}
extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get flow enabled status: %v", err)
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
for _, peer := range account.Peers {
if !c.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
continue
}
wg.Add(1)
semaphore <- struct{}{}
go func(p *nbpeer.Peer) {
defer wg.Done()
defer func() { <-semaphore }()
start := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, p.ID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
return
}
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now()
update := grpc.ToSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
}(peer)
}
wg.Wait()
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
}
return nil
}
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return nil
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID)
})
return
}
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
}()
return nil
}
// UpdatePeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return fmt.Errorf("recalculate network map cache: %v", err)
}
return c.sendUpdateAccountPeers(ctx, accountID)
}
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
if !c.peersUpdateManager.HasChannel(peerId) {
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
return fmt.Errorf("failed to send out updates to peer %s: %v", peerId, err)
}
peer := account.GetPeer(peerId)
if peer == nil {
return fmt.Errorf("peer %s doesn't exists in account %s", peerId, accountId)
}
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return fmt.Errorf("failed to get validated peers: %v", err)
}
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
postureChecks, err := c.getPeerPostureChecks(account, peerId)
if err != nil {
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err)
return fmt.Errorf("failed to get posture checks for peer %s: %v", peerId, err)
}
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return err
}
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
if err != nil {
return fmt.Errorf("failed to get extra settings: %v", err)
}
peerGroups := account.GetPeerGroups(peerId)
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
update := grpc.ToSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
return nil
}
func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return nil
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
_ = c.UpdateAccountPeers(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.UpdateAccountPeers(ctx, accountID)
})
return
}
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
}()
return nil
}
func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error {
network, err := c.repo.GetAccountNetwork(ctx, accountId)
if err != nil {
return err
}
peers, err := c.repo.GetAccountPeers(ctx, accountId)
if err != nil {
return err
}
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{
Update: &proto.SyncResponse{
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
NetworkMap: &proto.NetworkMap{
Serial: network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
DNSConfig: &proto.DNSConfig{
ForwarderPort: dnsFwdPort,
},
},
},
})
c.peersUpdateManager.CloseChannel(ctx, peerId)
return nil
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
emptyMap := &types.NetworkMap{
Network: network.Copy(),
}
return peer, emptyMap, nil, 0, nil
}
var (
account *types.Account
err error
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
}
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, 0, err
}
startPosture := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, 0, err
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, nil, nil, 0, err
}
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
networkMap.Merge(proxyNetworkMap)
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
return peer, networkMap, postureChecks, dnsFwdPort, nil
}
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
c.enrichAccountFromHolder(account)
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
}
func (c *Controller) getPeerNetworkMapExp(
ctx context.Context,
accountId string,
peerId string,
validatedPeers map[string]struct{},
customZone nbdns.CustomZone,
metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
account := c.getAccountFromHolderOrInit(accountId)
if account == nil {
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
return &types.NetworkMap{
Network: &types.Network{},
}
}
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
}
func (c *Controller) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error {
c.enrichAccountFromHolder(account)
return account.OnPeerAddedUpdNetworkMapCache(peerId)
}
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
c.enrichAccountFromHolder(account)
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
}
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
account := c.getAccountFromHolder(accountId)
if account == nil {
return
}
account.UpdatePeerInNetworkMapCache(peer)
}
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
account.RecalculateNetworkMapCache(validatedPeers)
c.updateAccountInHolder(account)
}
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
if c.experimentalNetworkMap(accountId) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
return err
}
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
return err
}
c.recalculateNetworkMapCache(account, validatedPeers)
}
return nil
}
func (c *Controller) experimentalNetworkMap(accountId string) bool {
_, ok := c.expNewNetworkMapAIDs[accountId]
return c.expNewNetworkMap || ok
}
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
a := c.holder.GetAccount(account.Id)
if a == nil {
c.holder.AddAccount(account)
return
}
account.NetworkMapCache = a.NetworkMapCache
if account.NetworkMapCache == nil {
return
}
account.NetworkMapCache.UpdateAccountPointer(account)
c.holder.AddAccount(account)
}
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
return c.holder.GetAccount(accountID)
}
func (c *Controller) getAccountFromHolderOrInit(accountID string) *types.Account {
a := c.holder.GetAccount(accountID)
if a != nil {
return a
}
account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountWithBackpressure)
if err != nil {
return nil
}
return account
}
func (c *Controller) updateAccountInHolder(account *types.Account) {
c.holder.AddAccount(account)
}
// GetDNSDomain returns the configured dnsDomain
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
if settings == nil {
return c.dnsDomain
}
if settings.DNSDomain == "" {
return c.dnsDomain
}
return settings.DNSDomain
}
// getPeerPostureChecks returns the posture checks applied for a given peer.
func (c *Controller) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) {
peerPostureChecks := make(map[string]*posture.Checks)
if len(account.PostureChecks) == 0 {
return nil, nil
}
for _, policy := range account.Policies {
if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
continue
}
if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
return nil, err
}
}
return maps.Values(peerPostureChecks), nil
}
func (c *Controller) StartWarmup(ctx context.Context) {
var initialInterval int64
intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS")
interval, err := strconv.Atoi(intervalStr)
if err != nil {
initialInterval = 1
log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err)
} else {
initialInterval = int64(interval) * 10
go func() {
startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S")
startupPeriod, err := strconv.Atoi(startupPeriodStr)
if err != nil {
startupPeriod = 1
log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err)
}
time.Sleep(time.Duration(startupPeriod) * time.Second)
c.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond))
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval)
}()
}
c.updateAccountPeersBufferInterval.Store(int64(time.Duration(initialInterval) * time.Millisecond))
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval)
}
// computeForwarderPort checks if all peers in the account have updated to a specific version or newer.
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
if len(peers) == 0 {
return int64(network_map.OldForwarderPort)
}
reqVer := semver.Canonical(requiredVersion)
// Check if all peers have the required version or newer
for _, peer := range peers {
// Development version is always supported
if peer.Meta.WtVersion == "development" {
continue
}
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
if peerVersion == "" {
// If any peer doesn't have version info, return 0
return int64(network_map.OldForwarderPort)
}
// Compare versions
if semver.Compare(peerVersion, reqVer) < 0 {
return int64(network_map.OldForwarderPort)
}
}
// All peers have the required version or newer
return int64(network_map.DnsForwarderPort)
}
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error {
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
if err != nil {
return err
}
if !isInGroup {
return nil
}
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
postureCheck := account.GetPostureChecks(sourcePostureCheckID)
if postureCheck == nil {
return errors.New("failed to add policy posture checks: posture checks not found")
}
peerPostureChecks[sourcePostureCheckID] = postureCheck
}
return nil
}
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
for _, sourceGroup := range rule.Sources {
group := account.GetGroup(sourceGroup)
if group == nil {
return false, fmt.Errorf("failed to check peer in policy source group: group not found")
}
if slices.Contains(group.Peers, peerID) {
return true, nil
}
}
}
return false, nil
}
func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) {
c.UpdatePeerInNetworkMapCache(accountId, peer)
_ = c.bufferSendUpdateAccountPeers(context.Background(), accountId)
}
func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error {
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
if err != nil {
return err
}
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
}
func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error {
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
if err != nil {
return err
}
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
account, err := c.repo.GetAccountByPeerID(ctx, peerID)
if err != nil {
return nil, err
}
peer := account.GetPeer(peerID)
if peer == nil {
return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
}
groups := make(map[string][]string)
for groupID, group := range account.Groups {
groups[groupID] = group.Peers
}
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, err
}
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, err
}
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
networkMap.Merge(proxyNetworkMap)
}
return networkMap, nil
}
func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) {
c.peersUpdateManager.CloseChannels(ctx, peerIDs)
}
func (c *Controller) IsConnected(peerID string) bool {
return c.peersUpdateManager.HasChannel(peerID)
}

View File

@@ -0,0 +1,244 @@
package controller
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
func TestComputeForwarderPort(t *testing.T) {
// Test with empty peers list
peers := []*nbpeer.Peer{}
result := computeForwarderPort(peers, "v0.59.0")
if result != int64(network_map.OldForwarderPort) {
t.Errorf("Expected %d for empty peers list, got %d", network_map.OldForwarderPort, result)
}
// Test with peers that have old versions
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.57.0",
},
},
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.26.0",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(network_map.OldForwarderPort) {
t.Errorf("Expected %d for peers with old versions, got %d", network_map.OldForwarderPort, result)
}
// Test with peers that have new versions
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.59.0",
},
},
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.59.0",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(network_map.DnsForwarderPort) {
t.Errorf("Expected %d for peers with new versions, got %d", network_map.DnsForwarderPort, result)
}
// Test with peers that have mixed versions
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.59.0",
},
},
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.57.0",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(network_map.OldForwarderPort) {
t.Errorf("Expected %d for peers with mixed versions, got %d", network_map.OldForwarderPort, result)
}
// Test with peers that have empty version
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(network_map.OldForwarderPort) {
t.Errorf("Expected %d for peers with empty version, got %d", network_map.OldForwarderPort, result)
}
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "development",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result == int64(network_map.OldForwarderPort) {
t.Errorf("Expected %d for peers with dev version, got %d", network_map.DnsForwarderPort, result)
}
// Test with peers that have unknown version string
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "unknown",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(network_map.OldForwarderPort) {
t.Errorf("Expected %d for peers with unknown version, got %d", network_map.OldForwarderPort, result)
}
}
func TestBufferUpdateAccountPeers(t *testing.T) {
const (
peersCount = 1000
updateAccountInterval = 50 * time.Millisecond
)
var (
deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32
uapLastRun, dpLastRun atomic.Int64
totalNewRuns, totalOldRuns int
)
uap := func(ctx context.Context, accountID string) {
updatePeersDeleted.Store(deletedPeers.Load())
updatePeersRuns.Add(1)
uapLastRun.Store(time.Now().UnixMilli())
time.Sleep(100 * time.Millisecond)
}
t.Run("new approach", func(t *testing.T) {
updatePeersRuns.Store(0)
updatePeersDeleted.Store(0)
deletedPeers.Store(0)
var mustore sync.Map
bufupd := func(ctx context.Context, accountID string) {
mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{})
b := mu.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
uap(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
b.next = time.AfterFunc(updateAccountInterval, func() {
uap(ctx, accountID)
})
}()
}
dp := func(ctx context.Context, accountID, peerID, userID string) error {
deletedPeers.Add(1)
dpLastRun.Store(time.Now().UnixMilli())
time.Sleep(10 * time.Millisecond)
bufupd(ctx, accountID)
return nil
}
am := mock_server.MockAccountManager{
UpdateAccountPeersFunc: uap,
BufferUpdateAccountPeersFunc: bufupd,
DeletePeerFunc: dp,
}
empty := ""
for range peersCount {
//nolint
am.DeletePeer(context.Background(), empty, empty, empty)
}
time.Sleep(100 * time.Millisecond)
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
totalNewRuns = int(updatePeersRuns.Load())
})
t.Run("old approach", func(t *testing.T) {
updatePeersRuns.Store(0)
updatePeersDeleted.Store(0)
deletedPeers.Store(0)
var mustore sync.Map
bufupd := func(ctx context.Context, accountID string) {
mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{})
b := mu.(*sync.Mutex)
if !b.TryLock() {
return
}
go func() {
time.Sleep(updateAccountInterval)
b.Unlock()
uap(ctx, accountID)
}()
}
dp := func(ctx context.Context, accountID, peerID, userID string) error {
deletedPeers.Add(1)
dpLastRun.Store(time.Now().UnixMilli())
time.Sleep(10 * time.Millisecond)
bufupd(ctx, accountID)
return nil
}
am := mock_server.MockAccountManager{
UpdateAccountPeersFunc: uap,
BufferUpdateAccountPeersFunc: bufupd,
DeletePeerFunc: dp,
}
empty := ""
for range peersCount {
//nolint
am.DeletePeer(context.Background(), empty, empty, empty)
}
time.Sleep(100 * time.Millisecond)
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
totalOldRuns = int(updatePeersRuns.Load())
})
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
}

View File

@@ -0,0 +1,15 @@
package controller
import (
"github.com/netbirdio/netbird/management/server/telemetry"
)
type metrics struct {
*telemetry.UpdateChannelMetrics
}
func newMetrics(updateChannelMetrics *telemetry.UpdateChannelMetrics) (*metrics, error) {
return &metrics{
updateChannelMetrics,
}, nil
}

View File

@@ -0,0 +1,39 @@
package controller
import (
"context"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
type Repository interface {
GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error)
GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
}
type repository struct {
store store.Store
}
var _ Repository = (*repository)(nil)
func newRepository(s store.Store) Repository {
return &repository{
store: s,
}
}
func (r *repository) GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error) {
return r.store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
}
func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error) {
return r.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
}
func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) {
return r.store.GetAccountByPeerID(ctx, peerID)
}

View File

@@ -0,0 +1,39 @@
package network_map
//go:generate go run go.uber.org/mock/mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
import (
"context"
nbdns "github.com/netbirdio/netbird/dns"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
)
const (
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
DnsForwarderPort = nbdns.ForwarderServerPort
OldForwarderPort = nbdns.ForwarderClientPort
DnsForwarderPortMinVersion = "v0.59.0"
)
type Controller interface {
UpdateAccountPeers(ctx context.Context, accountID string) error
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
DeletePeer(ctx context.Context, accountId string, peerId string) error
OnPeerUpdated(accountId string, peer *nbpeer.Peer)
OnPeerAdded(ctx context.Context, accountID string, peerID string) error
OnPeerDeleted(ctx context.Context, accountID string, peerID string) error
DisconnectPeers(ctx context.Context, peerIDs []string)
IsConnected(peerID string) bool
}

View File

@@ -0,0 +1,225 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./interface.go
//
// Generated by this command:
//
// mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
//
// Package network_map is a generated GoMock package.
package network_map
import (
context "context"
reflect "reflect"
peer "github.com/netbirdio/netbird/management/server/peer"
posture "github.com/netbirdio/netbird/management/server/posture"
types "github.com/netbirdio/netbird/management/server/types"
gomock "go.uber.org/mock/gomock"
)
// MockController is a mock of Controller interface.
type MockController struct {
ctrl *gomock.Controller
recorder *MockControllerMockRecorder
isgomock struct{}
}
// MockControllerMockRecorder is the mock recorder for MockController.
type MockControllerMockRecorder struct {
mock *MockController
}
// NewMockController creates a new mock instance.
func NewMockController(ctrl *gomock.Controller) *MockController {
mock := &MockController{ctrl: ctrl}
mock.recorder = &MockControllerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockController) EXPECT() *MockControllerMockRecorder {
return m.recorder
}
// BufferUpdateAccountPeers mocks base method.
func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID)
ret0, _ := ret[0].(error)
return ret0
}
// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers.
func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID)
}
// DeletePeer mocks base method.
func (m *MockController) DeletePeer(ctx context.Context, accountId, peerId string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeletePeer", ctx, accountId, peerId)
ret0, _ := ret[0].(error)
return ret0
}
// DeletePeer indicates an expected call of DeletePeer.
func (mr *MockControllerMockRecorder) DeletePeer(ctx, accountId, peerId any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeer", reflect.TypeOf((*MockController)(nil).DeletePeer), ctx, accountId, peerId)
}
// DisconnectPeers mocks base method.
func (m *MockController) DisconnectPeers(ctx context.Context, peerIDs []string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DisconnectPeers", ctx, peerIDs)
}
// DisconnectPeers indicates an expected call of DisconnectPeers.
func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, peerIDs)
}
// GetDNSDomain mocks base method.
func (m *MockController) GetDNSDomain(settings *types.Settings) string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetDNSDomain", settings)
ret0, _ := ret[0].(string)
return ret0
}
// GetDNSDomain indicates an expected call of GetDNSDomain.
func (mr *MockControllerMockRecorder) GetDNSDomain(settings any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSDomain", reflect.TypeOf((*MockController)(nil).GetDNSDomain), settings)
}
// GetNetworkMap mocks base method.
func (m *MockController) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetNetworkMap", ctx, peerID)
ret0, _ := ret[0].(*types.NetworkMap)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetNetworkMap indicates an expected call of GetNetworkMap.
func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNetworkMap", reflect.TypeOf((*MockController)(nil).GetNetworkMap), ctx, peerID)
}
// GetValidatedPeerWithMap mocks base method.
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMap)
ret2, _ := ret[2].([]*posture.Checks)
ret3, _ := ret[3].(int64)
ret4, _ := ret[4].(error)
return ret0, ret1, ret2, ret3, ret4
}
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
}
// IsConnected mocks base method.
func (m *MockController) IsConnected(peerID string) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsConnected", peerID)
ret0, _ := ret[0].(bool)
return ret0
}
// IsConnected indicates an expected call of IsConnected.
func (mr *MockControllerMockRecorder) IsConnected(peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockController)(nil).IsConnected), peerID)
}
// OnPeerAdded mocks base method.
func (m *MockController) OnPeerAdded(ctx context.Context, accountID, peerID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeerAdded", ctx, accountID, peerID)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeerAdded indicates an expected call of OnPeerAdded.
func (mr *MockControllerMockRecorder) OnPeerAdded(ctx, accountID, peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerAdded", reflect.TypeOf((*MockController)(nil).OnPeerAdded), ctx, accountID, peerID)
}
// OnPeerDeleted mocks base method.
func (m *MockController) OnPeerDeleted(ctx context.Context, accountID, peerID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeerDeleted", ctx, accountID, peerID)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeerDeleted indicates an expected call of OnPeerDeleted.
func (mr *MockControllerMockRecorder) OnPeerDeleted(ctx, accountID, peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDeleted", reflect.TypeOf((*MockController)(nil).OnPeerDeleted), ctx, accountID, peerID)
}
// OnPeerUpdated mocks base method.
func (m *MockController) OnPeerUpdated(accountId string, peer *peer.Peer) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnPeerUpdated", accountId, peer)
}
// OnPeerUpdated indicates an expected call of OnPeerUpdated.
func (mr *MockControllerMockRecorder) OnPeerUpdated(accountId, peer any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerUpdated", reflect.TypeOf((*MockController)(nil).OnPeerUpdated), accountId, peer)
}
// StartWarmup mocks base method.
func (m *MockController) StartWarmup(arg0 context.Context) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "StartWarmup", arg0)
}
// StartWarmup indicates an expected call of StartWarmup.
func (mr *MockControllerMockRecorder) StartWarmup(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartWarmup", reflect.TypeOf((*MockController)(nil).StartWarmup), arg0)
}
// UpdateAccountPeer mocks base method.
func (m *MockController) UpdateAccountPeer(ctx context.Context, accountId, peerId string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAccountPeer", ctx, accountId, peerId)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateAccountPeer indicates an expected call of UpdateAccountPeer.
func (mr *MockControllerMockRecorder) UpdateAccountPeer(ctx, accountId, peerId any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeer", reflect.TypeOf((*MockController)(nil).UpdateAccountPeer), ctx, accountId, peerId)
}
// UpdateAccountPeers mocks base method.
func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateAccountPeers indicates an expected call of UpdateAccountPeers.
func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID)
}

View File

@@ -0,0 +1 @@
package network_map

View File

@@ -0,0 +1,13 @@
package network_map
import "context"
type PeersUpdateManager interface {
SendUpdate(ctx context.Context, peerID string, update *UpdateMessage)
CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage
CloseChannel(ctx context.Context, peerID string)
CountStreams() int
HasChannel(peerID string) bool
CloseChannels(ctx context.Context, peerIDs []string)
GetAllConnectedPeers() map[string]struct{}
}

View File

@@ -1,4 +1,4 @@
package server package update_channel
import ( import (
"context" "context"
@@ -7,36 +7,34 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/shared/management/proto"
) )
const channelBufferSize = 100 const channelBufferSize = 100
type UpdateMessage struct {
Update *proto.SyncResponse
}
type PeersUpdateManager struct { type PeersUpdateManager struct {
// peerChannels is an update channel indexed by Peer.ID // peerChannels is an update channel indexed by Peer.ID
peerChannels map[string]chan *UpdateMessage peerChannels map[string]chan *network_map.UpdateMessage
// channelsMux keeps the mutex to access peerChannels // channelsMux keeps the mutex to access peerChannels
channelsMux *sync.RWMutex channelsMux *sync.RWMutex
// metrics provides method to collect application metrics // metrics provides method to collect application metrics
metrics telemetry.AppMetrics metrics telemetry.AppMetrics
} }
var _ network_map.PeersUpdateManager = (*PeersUpdateManager)(nil)
// NewPeersUpdateManager returns a new instance of PeersUpdateManager // NewPeersUpdateManager returns a new instance of PeersUpdateManager
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
return &PeersUpdateManager{ return &PeersUpdateManager{
peerChannels: make(map[string]chan *UpdateMessage), peerChannels: make(map[string]chan *network_map.UpdateMessage),
channelsMux: &sync.RWMutex{}, channelsMux: &sync.RWMutex{},
metrics: metrics, metrics: metrics,
} }
} }
// SendUpdate sends update message to the peer's channel // SendUpdate sends update message to the peer's channel
func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) { func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *network_map.UpdateMessage) {
start := time.Now() start := time.Now()
var found, dropped bool var found, dropped bool
@@ -64,7 +62,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
} }
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer. // CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage { func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *network_map.UpdateMessage {
start := time.Now() start := time.Now()
closed := false closed := false
@@ -83,7 +81,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c
close(channel) close(channel)
} }
// mbragin: todo shouldn't it be more? or configurable? // mbragin: todo shouldn't it be more? or configurable?
channel := make(chan *UpdateMessage, channelBufferSize) channel := make(chan *network_map.UpdateMessage, channelBufferSize)
p.peerChannels[peerID] = channel p.peerChannels[peerID] = channel
log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID) log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID)
@@ -174,3 +172,9 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool {
return ok return ok
} }
func (p *PeersUpdateManager) CountStreams() int {
p.channelsMux.RLock()
defer p.channelsMux.RUnlock()
return len(p.peerChannels)
}

View File

@@ -1,10 +1,11 @@
package server package update_channel
import ( import (
"context" "context"
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
) )
@@ -24,7 +25,7 @@ func TestCreateChannel(t *testing.T) {
func TestSendUpdate(t *testing.T) { func TestSendUpdate(t *testing.T) {
peer := "test-sendupdate" peer := "test-sendupdate"
peersUpdater := NewPeersUpdateManager(nil) peersUpdater := NewPeersUpdateManager(nil)
update1 := &UpdateMessage{Update: &proto.SyncResponse{ update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{ NetworkMap: &proto.NetworkMap{
Serial: 0, Serial: 0,
}, },
@@ -44,7 +45,7 @@ func TestSendUpdate(t *testing.T) {
peersUpdater.SendUpdate(context.Background(), peer, update1) peersUpdater.SendUpdate(context.Background(), peer, update1)
} }
update2 := &UpdateMessage{Update: &proto.SyncResponse{ update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{ NetworkMap: &proto.NetworkMap{
Serial: 10, Serial: 10,
}, },

View File

@@ -0,0 +1,9 @@
package network_map
import (
"github.com/netbirdio/netbird/shared/management/proto"
)
type UpdateMessage struct {
Update *proto.SyncResponse
}

View File

@@ -22,7 +22,7 @@ import (
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/formatter/hook"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbContext "github.com/netbirdio/netbird/management/server/context" nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http" nbhttp "github.com/netbirdio/netbird/management/server/http"
@@ -93,7 +93,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler { func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler { return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager()) httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController())
if err != nil { if err != nil {
log.Fatalf("failed to create API handler: %v", err) log.Fatalf("failed to create API handler: %v", err)
} }
@@ -145,7 +145,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
} }
gRPCAPIHandler := grpc.NewServer(gRPCOpts...) gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(context.Background(), s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator()) srv, err := nbgrpc.NewServer(s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController())
if err != nil { if err != nil {
log.Fatalf("failed to create management server: %v", err) log.Fatalf("failed to create management server: %v", err)
} }

View File

@@ -6,6 +6,10 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
@@ -14,9 +18,9 @@ import (
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
) )
func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager { func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
return Create(s, func() *server.PeersUpdateManager { return Create(s, func() *update_channel.PeersUpdateManager {
return server.NewPeersUpdateManager(s.Metrics()) return update_channel.NewPeersUpdateManager(s.Metrics())
}) })
} }
@@ -40,9 +44,9 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller {
}) })
} }
func (s *BaseServer) SecretsManager() *server.TimeBasedAuthSecretsManager { func (s *BaseServer) SecretsManager() *grpc.TimeBasedAuthSecretsManager {
return Create(s, func() *server.TimeBasedAuthSecretsManager { return Create(s, func() *grpc.TimeBasedAuthSecretsManager {
return server.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager()) return grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager())
}) })
} }
@@ -63,3 +67,15 @@ func (s *BaseServer) EphemeralManager() ephemeral.Manager {
return manager.NewEphemeralManager(s.Store(), s.AccountManager()) return manager.NewEphemeralManager(s.Store(), s.AccountManager())
}) })
} }
func (s *BaseServer) NetworkMapController() network_map.Controller {
return Create(s, func() *nmapcontroller.Controller {
return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController())
})
}
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
return Create(s, func() *server.AccountRequestBuffer {
return server.NewAccountRequestBuffer(context.Background(), s.Store())
})
}

View File

@@ -66,8 +66,7 @@ func (s *BaseServer) PeersManager() peers.Manager {
func (s *BaseServer) AccountManager() account.Manager { func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager { return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.Store(), s.PeersUpdateManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, accountManager, err := server.BuildManager(context.Background(), s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
s.dnsDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
if err != nil { if err != nil {
log.Fatalf("failed to create account manager: %v", err) log.Fatalf("failed to create account manager: %v", err)
} }

View File

@@ -0,0 +1,352 @@
package grpc
import (
"context"
"fmt"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto"
)
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
if config == nil {
return nil
}
var stuns []*proto.HostConfig
for _, stun := range config.Stuns {
stuns = append(stuns, &proto.HostConfig{
Uri: stun.URI,
Protocol: ToResponseProto(stun.Proto),
})
}
var turns []*proto.ProtectedHostConfig
if config.TURNConfig != nil {
for _, turn := range config.TURNConfig.Turns {
var username string
var password string
if turnCredentials != nil {
username = turnCredentials.Payload
password = turnCredentials.Signature
} else {
username = turn.Username
password = turn.Password
}
turns = append(turns, &proto.ProtectedHostConfig{
HostConfig: &proto.HostConfig{
Uri: turn.URI,
Protocol: ToResponseProto(turn.Proto),
},
User: username,
Password: password,
})
}
}
var relayCfg *proto.RelayConfig
if config.Relay != nil && len(config.Relay.Addresses) > 0 {
relayCfg = &proto.RelayConfig{
Urls: config.Relay.Addresses,
}
if relayToken != nil {
relayCfg.TokenPayload = relayToken.Payload
relayCfg.TokenSignature = relayToken.Signature
}
}
var signalCfg *proto.HostConfig
if config.Signal != nil {
signalCfg = &proto.HostConfig{
Uri: config.Signal.URI,
Protocol: ToResponseProto(config.Signal.Proto),
}
}
nbConfig := &proto.NetbirdConfig{
Stuns: stuns,
Turns: turns,
Signal: signalCfg,
Relay: relayCfg,
}
return nbConfig
}
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig {
netmask, _ := network.Net.Mask.Size()
fqdn := peer.FQDN(dnsName)
return &proto.PeerConfig{
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
Fqdn: fqdn,
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
LazyConnectionEnabled: settings.LazyConnectionEnabled,
}
}
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
},
Checks: toProtocolChecks(ctx, checks),
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
response.NetbirdConfig = extendedConfig
response.NetworkMap.PeerConfig = response.PeerConfig
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName)
response.RemotePeers = remotePeers
response.NetworkMap.RemotePeers = remotePeers
response.RemotePeersIsEmpty = len(remotePeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
response.NetworkMap.FirewallRules = firewallRules
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
if networkMap.ForwardingRules != nil {
forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules))
for _, rule := range networkMap.ForwardingRules {
forwardingRules = append(forwardingRules, rule.ToProto())
}
response.NetworkMap.ForwardingRules = forwardingRules
}
return response
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{rPeer.IP.String() + "/32"},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
AgentVersion: rPeer.Meta.WtVersion,
})
}
return dst
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
ForwarderPort: forwardPort,
}
for _, zone := range update.CustomZones {
protoZone := convertToProtoCustomZone(zone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
for _, nsGroup := range update.NameServerGroups {
cacheKey := nsGroup.ID
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
} else {
protoGroup := convertToProtoNameServerGroup(nsGroup)
cache.SetNameServerGroup(cacheKey, protoGroup)
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
}
return protoUpdate
}
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
switch configProto {
case nbconfig.UDP:
return proto.HostConfig_UDP
case nbconfig.DTLS:
return proto.HostConfig_DTLS
case nbconfig.HTTP:
return proto.HostConfig_HTTP
case nbconfig.HTTPS:
return proto.HostConfig_HTTPS
case nbconfig.TCP:
return proto.HostConfig_TCP
default:
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
}
}
func toProtocolRoutes(routes []*route.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0, len(routes))
for _, r := range routes {
protoRoutes = append(protoRoutes, toProtocolRoute(r))
}
return protoRoutes
}
func toProtocolRoute(route *route.Route) *proto.Route {
return &proto.Route{
ID: string(route.ID),
NetID: string(route.NetID),
Network: route.Network.String(),
Domains: route.Domains.ToPunycodeList(),
NetworkType: int64(route.NetworkType),
Peer: route.Peer,
Metric: int64(route.Metric),
Masquerade: route.Masquerade,
KeepRoute: route.KeepRoute,
SkipAutoApply: route.SkipAutoApply,
}
}
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules))
for i := range rules {
rule := rules[i]
fwRule := &proto.FirewallRule{
PolicyID: []byte(rule.PolicyID),
PeerIP: rule.PeerIP,
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
}
if shouldUsePortRange(fwRule) {
fwRule.PortInfo = rule.PortRange.ToProto()
}
result[i] = fwRule
}
return result
}
// getProtoDirection converts the direction to proto.RuleDirection.
func getProtoDirection(direction int) proto.RuleDirection {
if direction == types.FirewallRuleDirectionOUT {
return proto.RuleDirection_OUT
}
return proto.RuleDirection_IN
}
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
result := make([]*proto.RouteFirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.RouteFirewallRule{
SourceRanges: rule.SourceRanges,
Action: getProtoAction(rule.Action),
Destination: rule.Destination,
Protocol: getProtoProtocol(rule.Protocol),
PortInfo: getProtoPortInfo(rule),
IsDynamic: rule.IsDynamic,
Domains: rule.Domains.ToPunycodeList(),
PolicyID: []byte(rule.PolicyID),
RouteID: string(rule.RouteID),
}
}
return result
}
// getProtoAction converts the action to proto.RuleAction.
func getProtoAction(action string) proto.RuleAction {
if action == string(types.PolicyTrafficActionDrop) {
return proto.RuleAction_DROP
}
return proto.RuleAction_ACCEPT
}
// getProtoProtocol converts the protocol to proto.RuleProtocol.
func getProtoProtocol(protocol string) proto.RuleProtocol {
switch types.PolicyRuleProtocolType(protocol) {
case types.PolicyRuleProtocolALL:
return proto.RuleProtocol_ALL
case types.PolicyRuleProtocolTCP:
return proto.RuleProtocol_TCP
case types.PolicyRuleProtocolUDP:
return proto.RuleProtocol_UDP
case types.PolicyRuleProtocolICMP:
return proto.RuleProtocol_ICMP
default:
return proto.RuleProtocol_UNKNOWN
}
}
// getProtoPortInfo converts the port info to proto.PortInfo.
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
var portInfo proto.PortInfo
if rule.Port != 0 {
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
portInfo.PortSelection = &proto.PortInfo_Range_{
Range: &proto.PortInfo_Range{
Start: uint32(portRange.Start),
End: uint32(portRange.End),
},
}
}
return &portInfo
}
func shouldUsePortRange(rule *proto.FirewallRule) bool {
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
}
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
return protoZone
}
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
})
}
return protoGroup
}

View File

@@ -0,0 +1,150 @@
package grpc
import (
"fmt"
"net/netip"
"reflect"
"testing"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
)
func TestToProtocolDNSConfigWithCache(t *testing.T) {
var cache cache.DNSConfigCache
// Create two different configs
config1 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.com",
Records: []nbdns.SimpleRecord{
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group1",
Name: "Group 1",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
},
},
},
}
config2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.org",
Records: []nbdns.SimpleRecord{
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group2",
Name: "Group 2",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
},
},
},
}
// First run with config1
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Second run with config2
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
// Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
}
// Verify that result2 is different from result1 and result3
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
t.Errorf("Results should be different for different inputs")
}
if _, exists := cache.GetNameServerGroup("group1"); !exists {
t.Errorf("Cache should contain name server group 'group1'")
}
if _, exists := cache.GetNameServerGroup("group2"); !exists {
t.Errorf("Cache should contain name server group 'group2'")
}
}
func BenchmarkToProtocolDNSConfig(b *testing.B) {
sizes := []int{10, 100, 1000}
for _, size := range sizes {
testData := generateTestData(size)
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
cache := &cache.DNSConfigCache{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
}
})
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &cache.DNSConfigCache{}
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
}
})
}
}
func generateTestData(size int) nbdns.Config {
config := nbdns.Config{
ServiceEnable: true,
CustomZones: make([]nbdns.CustomZone, size),
NameServerGroups: make([]*nbdns.NameServerGroup, size),
}
for i := 0; i < size; i++ {
config.CustomZones[i] = nbdns.CustomZone{
Domain: fmt.Sprintf("domain%d.com", i),
Records: []nbdns.SimpleRecord{
{
Name: fmt.Sprintf("record%d", i),
Type: 1,
Class: "IN",
TTL: 3600,
RData: "192.168.1.1",
},
},
}
config.NameServerGroups[i] = &nbdns.NameServerGroup{
ID: fmt.Sprintf("group%d", i),
Primary: i == 0,
Domains: []string{fmt.Sprintf("domain%d.com", i)},
SearchDomainsEnabled: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
Port: 53,
NSType: 1,
},
},
}
}
return config
}

View File

@@ -1,4 +1,4 @@
package server package grpc
import ( import (
"hash/fnv" "hash/fnv"

View File

@@ -1,4 +1,4 @@
package server package grpc
import ( import (
"hash/fnv" "hash/fnv"

View File

@@ -1,4 +1,4 @@
package server package grpc
import ( import (
"context" "context"
@@ -22,7 +22,7 @@ import (
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/peers/ephemeral"
@@ -51,13 +51,13 @@ const (
defaultSyncLim = 1000 defaultSyncLim = 1000
) )
// GRPCServer an instance of a Management gRPC API server // Server an instance of a Management gRPC API server
type GRPCServer struct { type Server struct {
accountManager account.Manager accountManager account.Manager
settingsManager settings.Manager settingsManager settings.Manager
wgKey wgtypes.Key wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager peersUpdateManager network_map.PeersUpdateManager
config *nbconfig.Config config *nbconfig.Config
secretsManager SecretsManager secretsManager SecretsManager
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
@@ -69,23 +69,27 @@ type GRPCServer struct {
blockPeersWithSameConfig bool blockPeersWithSameConfig bool
integratedPeerValidator integrated_validator.IntegratedValidator integratedPeerValidator integrated_validator.IntegratedValidator
loginFilter *loginFilter
networkMapController network_map.Controller
syncSem atomic.Int32 syncSem atomic.Int32
syncLim int32 syncLim int32
} }
// NewServer creates a new Management server // NewServer creates a new Management server
func NewServer( func NewServer(
ctx context.Context,
config *nbconfig.Config, config *nbconfig.Config,
accountManager account.Manager, accountManager account.Manager,
settingsManager settings.Manager, settingsManager settings.Manager,
peersUpdateManager *PeersUpdateManager, peersUpdateManager network_map.PeersUpdateManager,
secretsManager SecretsManager, secretsManager SecretsManager,
appMetrics telemetry.AppMetrics, appMetrics telemetry.AppMetrics,
ephemeralManager ephemeral.Manager, ephemeralManager ephemeral.Manager,
authManager auth.Manager, authManager auth.Manager,
integratedPeerValidator integrated_validator.IntegratedValidator, integratedPeerValidator integrated_validator.IntegratedValidator,
) (*GRPCServer, error) { networkMapController network_map.Controller,
) (*Server, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -94,7 +98,7 @@ func NewServer(
if appMetrics != nil { if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams // update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(len(peersUpdateManager.peerChannels)) return int64(peersUpdateManager.CountStreams())
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -115,7 +119,7 @@ func NewServer(
} }
} }
return &GRPCServer{ return &Server{
wgKey: key, wgKey: key,
// peerKey -> event channel // peerKey -> event channel
peersUpdateManager: peersUpdateManager, peersUpdateManager: peersUpdateManager,
@@ -129,12 +133,15 @@ func NewServer(
logBlockedPeers: logBlockedPeers, logBlockedPeers: logBlockedPeers,
blockPeersWithSameConfig: blockPeersWithSameConfig, blockPeersWithSameConfig: blockPeersWithSameConfig,
integratedPeerValidator: integratedPeerValidator, integratedPeerValidator: integratedPeerValidator,
networkMapController: networkMapController,
loginFilter: newLoginFilter(),
syncLim: syncLim, syncLim: syncLim,
}, nil }, nil
} }
func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
ip := "" ip := ""
p, ok := peer.FromContext(ctx) p, ok := peer.FromContext(ctx)
if ok { if ok {
@@ -171,7 +178,7 @@ func getRealIP(ctx context.Context) net.IP {
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account) // notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
if s.syncSem.Load() >= s.syncLim { if s.syncSem.Load() >= s.syncLim {
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later") return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
} }
@@ -191,7 +198,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
sRealIP := realIP.String() sRealIP := realIP.String()
peerMeta := extractPeerMeta(ctx, syncReq.GetMeta()) peerMeta := extractPeerMeta(ctx, syncReq.GetMeta())
metahashed := metaHash(peerMeta, sRealIP) metahashed := metaHash(peerMeta, sRealIP)
if !s.accountManager.AllowSync(peerKey.String(), metahashed) { if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.appMetrics != nil { if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked() s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
} }
@@ -245,35 +252,29 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
} }
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) metahash := metaHash(peerMeta, realIP.String())
s.loginFilter.addLogin(peerKey.String(), metahash)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
s.syncSem.Add(-1) s.syncSem.Add(-1)
return mapError(ctx, err) return mapError(ctx, err)
} }
log.WithContext(ctx).Debugf("Sync: SyncAndMarkPeer since start %v", time.Since(reqStart)) err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort)
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1) s.syncSem.Add(-1)
return err return err
} }
log.WithContext(ctx).Debugf("Sync: sendInitialSync since start %v", time.Since(reqStart))
updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID) updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
log.WithContext(ctx).Debugf("Sync: CreateChannel since start %v", time.Since(reqStart))
s.ephemeralManager.OnPeerConnected(ctx, peer) s.ephemeralManager.OnPeerConnected(ctx, peer)
log.WithContext(ctx).Debugf("Sync: OnPeerConnected since start %v", time.Since(reqStart))
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
log.WithContext(ctx).Debugf("Sync: SetupRefresh since start %v", time.Since(reqStart))
if s.appMetrics != nil { if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID)
} }
@@ -281,15 +282,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
unlock() unlock()
unlock = nil unlock = nil
log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart))
s.syncSem.Add(-1) s.syncSem.Add(-1)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
} }
// handleUpdates sends updates to the connected peer until the updates channel is closed. // handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String()) log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
for { for {
select { select {
@@ -323,7 +322,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
// sendUpdate encrypts the update message using the peer key and the server's wireguard key, // sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server. // then sends the encrypted message to the connected peer via the sync server.
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error { func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil { if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer) s.cancelPeerRoutines(ctx, accountID, peer)
@@ -341,7 +340,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
return nil return nil
} }
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key) unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock() defer unlock()
@@ -356,7 +355,7 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key) log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
} }
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) { func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
if s.authManager == nil { if s.authManager == nil {
return "", status.Errorf(codes.Internal, "missing auth manager") return "", status.Errorf(codes.Internal, "missing auth manager")
} }
@@ -390,7 +389,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
return userAuth.UserId, nil return userAuth.UserId, nil
} }
func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) { func (s *Server) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID) log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID)
start := time.Now() start := time.Now()
@@ -498,7 +497,7 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
} }
} }
func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey) log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
@@ -517,7 +516,7 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa
// In case it is, the login is successful // In case it is, the login is successful
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer. // In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
// In case of the successful registration login is also successful // In case of the successful registration login is also successful
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
reqStart := time.Now() reqStart := time.Now()
realIP := getRealIP(ctx) realIP := getRealIP(ctx)
sRealIP := realIP.String() sRealIP := realIP.String()
@@ -531,7 +530,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta()) peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
metahashed := metaHash(peerMeta, sRealIP) metahashed := metaHash(peerMeta, sRealIP)
if !s.accountManager.AllowSync(peerKey.String(), metahashed) { if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.logBlockedPeers { if s.logBlockedPeers {
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed) log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
} }
@@ -628,7 +627,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
}, nil }, nil
} }
func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) { func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
var relayToken *Token var relayToken *Token
var err error var err error
if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 { if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
@@ -647,7 +646,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer
// if peer has reached this point then it has logged in // if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{ loginResp := &proto.LoginResponse{
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), settings), PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings),
Checks: toProtocolChecks(ctx, postureChecks), Checks: toProtocolChecks(ctx, postureChecks),
} }
@@ -659,7 +658,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer
// //
// The user ID can be empty if the token is not provided, which is acceptable if the peer is already // The user ID can be empty if the token is not provided, which is acceptable if the peer is already
// registered or if it uses a setup key to register. // registered or if it uses a setup key to register.
func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) { func (s *Server) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
userID := "" userID := ""
if loginReq.GetJwtToken() != "" { if loginReq.GetJwtToken() != "" {
var err error var err error
@@ -679,166 +678,13 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR
return userID, nil return userID, nil
} }
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
switch configProto {
case nbconfig.UDP:
return proto.HostConfig_UDP
case nbconfig.DTLS:
return proto.HostConfig_DTLS
case nbconfig.HTTP:
return proto.HostConfig_HTTP
case nbconfig.HTTPS:
return proto.HostConfig_HTTPS
case nbconfig.TCP:
return proto.HostConfig_TCP
default:
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
}
}
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
if config == nil {
return nil
}
var stuns []*proto.HostConfig
for _, stun := range config.Stuns {
stuns = append(stuns, &proto.HostConfig{
Uri: stun.URI,
Protocol: ToResponseProto(stun.Proto),
})
}
var turns []*proto.ProtectedHostConfig
if config.TURNConfig != nil {
for _, turn := range config.TURNConfig.Turns {
var username string
var password string
if turnCredentials != nil {
username = turnCredentials.Payload
password = turnCredentials.Signature
} else {
username = turn.Username
password = turn.Password
}
turns = append(turns, &proto.ProtectedHostConfig{
HostConfig: &proto.HostConfig{
Uri: turn.URI,
Protocol: ToResponseProto(turn.Proto),
},
User: username,
Password: password,
})
}
}
var relayCfg *proto.RelayConfig
if config.Relay != nil && len(config.Relay.Addresses) > 0 {
relayCfg = &proto.RelayConfig{
Urls: config.Relay.Addresses,
}
if relayToken != nil {
relayCfg.TokenPayload = relayToken.Payload
relayCfg.TokenSignature = relayToken.Signature
}
}
var signalCfg *proto.HostConfig
if config.Signal != nil {
signalCfg = &proto.HostConfig{
Uri: config.Signal.URI,
Protocol: ToResponseProto(config.Signal.Proto),
}
}
nbConfig := &proto.NetbirdConfig{
Stuns: stuns,
Turns: turns,
Signal: signalCfg,
Relay: relayCfg,
}
return nbConfig
}
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig {
netmask, _ := network.Net.Mask.Size()
fqdn := peer.FQDN(dnsName)
return &proto.PeerConfig{
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
Fqdn: fqdn,
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
LazyConnectionEnabled: settings.LazyConnectionEnabled,
}
}
func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
},
Checks: toProtocolChecks(ctx, checks),
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
response.NetbirdConfig = extendedConfig
response.NetworkMap.PeerConfig = response.PeerConfig
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName)
response.RemotePeers = remotePeers
response.NetworkMap.RemotePeers = remotePeers
response.RemotePeersIsEmpty = len(remotePeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
response.NetworkMap.FirewallRules = firewallRules
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
if networkMap.ForwardingRules != nil {
forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules))
for _, rule := range networkMap.ForwardingRules {
forwardingRules = append(forwardingRules, rule.ToProto())
}
response.NetworkMap.ForwardingRules = forwardingRules
}
return response
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{rPeer.IP.String() + "/32"},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
AgentVersion: rPeer.Meta.WtVersion,
})
}
return dst
}
// IsHealthy indicates whether the service is healthy // IsHealthy indicates whether the service is healthy
func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) { func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) {
return &proto.Empty{}, nil return &proto.Empty{}, nil
} }
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer, dnsFwdPort int64) error {
var err error var err error
var turnToken *Token var turnToken *Token
@@ -862,19 +708,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
return status.Errorf(codes.Internal, "error handling request") return status.Errorf(codes.Internal, "error handling request")
} }
peerGroups, err := getPeerGroupIDs(ctx, s.accountManager.GetStore(), peer.AccountID, peer.ID) peerGroups, err := s.accountManager.GetStore().GetPeerGroupIDs(ctx, store.LockingStrengthNone, peer.AccountID, peer.ID)
if err != nil { if err != nil {
return status.Errorf(codes.Internal, "failed to get peer groups %s", err) return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
} }
// Get all peers in the account for forwarder port computation plainResp := ToSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
allPeers, err := s.accountManager.GetStore().GetAccountPeers(ctx, store.LockingStrengthNone, peer.AccountID, "", "")
if err != nil {
return fmt.Errorf("get account peers: %w", err)
}
dnsFwdPort := computeForwarderPort(allPeers, dnsForwarderPortMinVersion)
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil { if err != nil {
@@ -899,7 +738,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
// GetDeviceAuthorizationFlow returns a device authorization flow information // GetDeviceAuthorizationFlow returns a device authorization flow information
// This is used for initiating an Oauth 2 device authorization grant flow // This is used for initiating an Oauth 2 device authorization grant flow
// which will be used by our clients to Login // which will be used by our clients to Login
func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey) log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now() start := time.Now()
defer func() { defer func() {
@@ -957,7 +796,7 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
// GetPKCEAuthorizationFlow returns a pkce authorization flow information // GetPKCEAuthorizationFlow returns a pkce authorization flow information
// This is used for initiating an Oauth 2 pkce authorization grant flow // This is used for initiating an Oauth 2 pkce authorization grant flow
// which will be used by our clients to Login // which will be used by our clients to Login
func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey) log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey)
start := time.Now() start := time.Now()
defer func() { defer func() {
@@ -1012,7 +851,7 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En
// SyncMeta endpoint is used to synchronize peer's system metadata and notifies the connected, // SyncMeta endpoint is used to synchronize peer's system metadata and notifies the connected,
// peer's under the same account of any updates. // peer's under the same account of any updates.
func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
realIP := getRealIP(ctx) realIP := getRealIP(ctx)
log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String()) log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
@@ -1037,7 +876,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
return &proto.Empty{}, nil return &proto.Empty{}, nil
} }
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey) log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
start := time.Now() start := time.Now()

View File

@@ -0,0 +1,106 @@
package grpc
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/server/config"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
)
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
testingServerKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
}
testingClientKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
}
testCases := []struct {
name string
inputFlow *config.DeviceAuthorizationFlow
expectedFlow *mgmtProto.DeviceAuthorizationFlow
expectedErrFunc require.ErrorAssertionFunc
expectedErrMSG string
expectedComparisonFunc require.ComparisonAssertionFunc
expectedComparisonMSG string
}{
{
name: "Testing No Device Flow Config",
inputFlow: nil,
expectedErrFunc: require.Error,
expectedErrMSG: "should return error",
},
{
name: "Testing Invalid Device Flow Provider Config",
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "NoNe",
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
expectedErrFunc: require.Error,
expectedErrMSG: "should return error",
},
{
name: "Testing Full Device Flow Config",
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "hosted",
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
Provider: 0,
ProviderConfig: &mgmtProto.ProviderConfig{
ClientID: "test",
},
},
expectedErrFunc: require.NoError,
expectedErrMSG: "should not return error",
expectedComparisonFunc: require.Equal,
expectedComparisonMSG: "should match",
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &Server{
wgKey: testingServerKey,
config: &config.Config{
DeviceAuthorizationFlow: testCase.inputFlow,
},
}
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message)
require.NoError(t, err, "should be able to encrypt message")
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
context.TODO(),
&mgmtProto.EncryptedMessage{
WgPubKey: testingClientKey.PublicKey().String(),
Body: encryptedMSG,
},
)
testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
if testCase.expectedComparisonFunc != nil {
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
require.NoError(t, err, "should be able to decrypt")
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
}
})
}
}

View File

@@ -1,4 +1,4 @@
package server package grpc
import ( import (
"context" "context"
@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
@@ -37,7 +38,7 @@ type TimeBasedAuthSecretsManager struct {
relayCfg *nbconfig.Relay relayCfg *nbconfig.Relay
turnHmacToken *auth.TimedHMAC turnHmacToken *auth.TimedHMAC
relayHmacToken *authv2.Generator relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager updateManager network_map.PeersUpdateManager
settingsManager settings.Manager settingsManager settings.Manager
groupsManager groups.Manager groupsManager groups.Manager
turnCancelMap map[string]chan struct{} turnCancelMap map[string]chan struct{}
@@ -46,7 +47,7 @@ type TimeBasedAuthSecretsManager struct {
type Token auth.Token type Token auth.Token
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager { func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
mgr := &TimeBasedAuthSecretsManager{ mgr := &TimeBasedAuthSecretsManager{
updateManager: updateManager, updateManager: updateManager,
turnCfg: turnCfg, turnCfg: turnCfg,
@@ -227,7 +228,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
m.extendNetbirdConfig(ctx, peerID, accountID, update) m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
} }
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) { func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
@@ -251,7 +252,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
m.extendNetbirdConfig(ctx, peerID, accountID, update) m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID) log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
} }
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) { func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {

View File

@@ -1,4 +1,4 @@
package server package grpc
import ( import (
"context" "context"
@@ -13,6 +13,8 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
@@ -31,7 +33,7 @@ var TurnTestHost = &config.Host{
func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
ttl := util.Duration{Duration: time.Hour} ttl := util.Duration{Duration: time.Hour}
secret := "some_secret" secret := "some_secret"
peersManager := NewPeersUpdateManager(nil) peersManager := update_channel.NewPeersUpdateManager(nil)
rc := &config.Relay{ rc := &config.Relay{
Addresses: []string{"localhost:0"}, Addresses: []string{"localhost:0"},
@@ -80,7 +82,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
ttl := util.Duration{Duration: 2 * time.Second} ttl := util.Duration{Duration: 2 * time.Second}
secret := "some_secret" secret := "some_secret"
peersManager := NewPeersUpdateManager(nil) peersManager := update_channel.NewPeersUpdateManager(nil)
peer := "some_peer" peer := "some_peer"
updateChannel := peersManager.CreateChannel(context.Background(), peer) updateChannel := peersManager.CreateChannel(context.Background(), peer)
@@ -116,7 +118,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
t.Errorf("expecting peer to be present in the relay cancel map, got not present") t.Errorf("expecting peer to be present in the relay cancel map, got not present")
} }
var updates []*UpdateMessage var updates []*network_map.UpdateMessage
loop: loop:
for timeout := time.After(5 * time.Second); ; { for timeout := time.After(5 * time.Second); ; {
@@ -185,7 +187,7 @@ loop:
func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) { func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
ttl := util.Duration{Duration: time.Hour} ttl := util.Duration{Duration: time.Hour}
secret := "some_secret" secret := "some_secret"
peersManager := NewPeersUpdateManager(nil) peersManager := update_channel.NewPeersUpdateManager(nil)
peer := "some_peer" peer := "some_peer"
rc := &config.Relay{ rc := &config.Relay{

View File

@@ -11,10 +11,8 @@ import (
"reflect" "reflect"
"regexp" "regexp"
"slices" "slices"
"strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
cacheStore "github.com/eko/gocache/lib/v4/store" cacheStore "github.com/eko/gocache/lib/v4/store"
@@ -26,6 +24,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
@@ -53,9 +52,6 @@ const (
peerSchedulerRetryInterval = 3 * time.Second peerSchedulerRetryInterval = 3 * time.Second
emptyUserID = "empty user ID in claims" emptyUserID = "empty user ID in claims"
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
envNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
envNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
) )
type userLoggedInOnce bool type userLoggedInOnce bool
@@ -71,7 +67,7 @@ type DefaultAccountManager struct {
cacheMux sync.Mutex cacheMux sync.Mutex
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
cacheLoading map[string]chan struct{} cacheLoading map[string]chan struct{}
peersUpdateManager *PeersUpdateManager networkMapController network_map.Controller
idpManager idp.Manager idpManager idp.Manager
cacheManager *nbcache.AccountUserDataCache cacheManager *nbcache.AccountUserDataCache
externalCacheManager nbcache.UserDataCache externalCacheManager nbcache.UserDataCache
@@ -91,8 +87,7 @@ type DefaultAccountManager struct {
singleAccountMode bool singleAccountMode bool
// singleAccountModeDomain is a domain to use in singleAccountMode setup // singleAccountModeDomain is a domain to use in singleAccountMode setup
singleAccountModeDomain string singleAccountModeDomain string
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
peerLoginExpiry Scheduler peerLoginExpiry Scheduler
peerInactivityExpiry Scheduler peerInactivityExpiry Scheduler
@@ -106,19 +101,11 @@ type DefaultAccountManager struct {
permissionsManager permissions.Manager permissionsManager permissions.Manager
accountUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64
loginFilter *loginFilter
disableDefaultPolicy bool disableDefaultPolicy bool
holder *types.Holder
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
} }
var _ account.Manager = (*DefaultAccountManager)(nil)
func isUniqueConstraintError(err error) bool { func isUniqueConstraintError(err error) bool {
switch { switch {
case strings.Contains(err.Error(), "(SQLSTATE 23505)"), case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
@@ -185,10 +172,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups []
func BuildManager( func BuildManager(
ctx context.Context, ctx context.Context,
store store.Store, store store.Store,
peersUpdateManager *PeersUpdateManager, networkMapController network_map.Controller,
idpManager idp.Manager, idpManager idp.Manager,
singleAccountModeDomain string, singleAccountModeDomain string,
dnsDomain string,
eventStore activity.Store, eventStore activity.Store,
geo geolocation.Geolocation, geo geolocation.Geolocation,
userDeleteFromIDPEnabled bool, userDeleteFromIDPEnabled bool,
@@ -204,27 +190,14 @@ func BuildManager(
log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start)) log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start))
}() }()
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(envNewNetworkMapBuilder))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", envNewNetworkMapBuilder, err)
newNetworkMapBuilder = false
}
ids := strings.Split(os.Getenv(envNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
expIDs[id] = struct{}{}
}
am := &DefaultAccountManager{ am := &DefaultAccountManager{
Store: store, Store: store,
geo: geo, geo: geo,
peersUpdateManager: peersUpdateManager, networkMapController: networkMapController,
idpManager: idpManager, idpManager: idpManager,
ctx: context.Background(), ctx: context.Background(),
cacheMux: sync.Mutex{}, cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{}, cacheLoading: map[string]chan struct{}{},
dnsDomain: dnsDomain,
eventStore: eventStore, eventStore: eventStore,
peerLoginExpiry: NewDefaultScheduler(), peerLoginExpiry: NewDefaultScheduler(),
peerInactivityExpiry: NewDefaultScheduler(), peerInactivityExpiry: NewDefaultScheduler(),
@@ -235,15 +208,10 @@ func BuildManager(
proxyController: proxyController, proxyController: proxyController,
settingsManager: settingsManager, settingsManager: settingsManager,
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
loginFilter: newLoginFilter(),
disableDefaultPolicy: disableDefaultPolicy, disableDefaultPolicy: disableDefaultPolicy,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
} }
am.startWarmup(ctx) am.networkMapController.StartWarmup(ctx)
accountsCounter, err := store.GetAccountsCounter(ctx) accountsCounter, err := store.GetAccountsCounter(ctx)
if err != nil { if err != nil {
@@ -291,32 +259,6 @@ func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) {
am.ephemeralManager = em am.ephemeralManager = em
} }
func (am *DefaultAccountManager) startWarmup(ctx context.Context) {
var initialInterval int64
intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS")
interval, err := strconv.Atoi(intervalStr)
if err != nil {
initialInterval = 1
log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err)
} else {
initialInterval = int64(interval) * 10
go func() {
startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S")
startupPeriod, err := strconv.Atoi(startupPeriodStr)
if err != nil {
startupPeriod = 1
log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err)
}
time.Sleep(time.Duration(startupPeriod) * time.Second)
am.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond))
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval)
}()
}
am.updateAccountPeersBufferInterval.Store(initialInterval)
log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval)
}
func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager { func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager {
return am.externalCacheManager return am.externalCacheManager
} }
@@ -419,9 +361,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
} }
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers { if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
go am.UpdateAccountPeers(ctx, accountID) go am.UpdateAccountPeers(ctx, accountID)
} }
@@ -1504,10 +1443,6 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
} }
if removedGroupAffectsPeers || newGroupsAffectsPeers { if removedGroupAffectsPeers || newGroupsAffectsPeers {
if err := am.RecalculateNetworkMapCache(ctx, userAuth.AccountId); err != nil {
return err
}
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
} }
@@ -1667,14 +1602,10 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
} }
func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool { func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
return am.loginFilter.allowLogin(wgPubKey, metahash) peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil { if err != nil {
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err) return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
} }
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID) err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
@@ -1682,10 +1613,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
} }
metahash := metaHash(meta, realIP.String()) return peer, netMap, postureChecks, dnsfwdPort, nil
am.loginFilter.addLogin(peerPubKey, metahash)
return peer, netMap, postureChecks, nil
} }
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error { func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
@@ -1702,41 +1630,19 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
return err return err
} }
_, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) _, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
if err != nil { if err != nil {
return mapError(ctx, err) return err
} }
return nil return nil
} }
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
return am.peersUpdateManager.GetAllConnectedPeers(), nil
}
// HasConnectedChannel returns true if peers has channel in update manager, otherwise false
func (am *DefaultAccountManager) HasConnectedChannel(peerID string) bool {
return am.peersUpdateManager.HasChannel(peerID)
}
var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
func isDomainValid(domain string) bool { func isDomainValid(domain string) bool {
return invalidDomainRegexp.MatchString(domain) return invalidDomainRegexp.MatchString(domain)
} }
// GetDNSDomain returns the configured dnsDomain
func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string {
if settings == nil {
return am.dnsDomain
}
if settings.DNSDomain == "" {
return am.dnsDomain
}
return settings.DNSDomain
}
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) { func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) {
peers := []*nbpeer.Peer{} peers := []*nbpeer.Peer{}
log.WithContext(ctx).Debugf("invalidating peers %v for account %s", peerIDs, accountID) log.WithContext(ctx).Debugf("invalidating peers %v for account %s", peerIDs, accountID)
@@ -2159,8 +2065,7 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
if err != nil { if err != nil {
return err return err
} }
am.updatePeerInNetworkMapCache(peer.AccountID, peer) am.networkMapController.OnPeerUpdated(peer.AccountID, peer)
am.BufferUpdateAccountPeers(ctx, accountID)
} }
return nil return nil
} }
@@ -2208,7 +2113,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
if err != nil { if err != nil {
return fmt.Errorf("get account settings: %w", err) return fmt.Errorf("get account settings: %w", err)
} }
dnsDomain := am.GetDNSDomain(settings) dnsDomain := am.networkMapController.GetDNSDomain(settings)
eventMeta := peer.EventMeta(dnsDomain) eventMeta := peer.EventMeta(dnsDomain)
oldIP := peer.IP.String() oldIP := peer.IP.String()

View File

@@ -89,7 +89,6 @@ type Manager interface {
SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
GetDNSDomain(settings *types.Settings) string
StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error)
@@ -97,10 +96,8 @@ type Manager interface {
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error)
HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error)
@@ -110,7 +107,7 @@ type Manager interface {
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
@@ -127,6 +124,4 @@ type Manager interface {
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
SetEphemeralManager(em ephemeral.Manager) SetEphemeralManager(em ephemeral.Manager)
AllowSync(string, uint64) bool
RecalculateNetworkMapCache(ctx context.Context, accountId string) error
} }

View File

@@ -0,0 +1,11 @@
package account
import (
"context"
"github.com/netbirdio/netbird/management/server/types"
)
type RequestBuffer interface {
GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error)
}

View File

@@ -22,6 +22,9 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbAccount "github.com/netbirdio/netbird/management/server/account" nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/cache"
@@ -406,7 +409,7 @@ func TestNewAccount(t *testing.T) {
} }
func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -603,7 +606,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain) accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain)
@@ -644,7 +647,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
userId := "user-id" userId := "user-id"
domain := "test.domain" domain := "test.domain"
_ = newAccountWithId(context.Background(), "", userId, domain, false) _ = newAccountWithId(context.Background(), "", userId, domain, false)
manager, err := createManager(t) manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
require.NoError(t, err, "create init user failed") require.NoError(t, err, "create init user failed")
@@ -705,7 +708,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
} }
func TestAccountManager_PrivateAccount(t *testing.T) { func TestAccountManager_PrivateAccount(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -731,7 +734,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
} }
func TestAccountManager_SetOrUpdateDomain(t *testing.T) { func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -768,7 +771,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
} }
func TestAccountManager_GetAccountByUserID(t *testing.T) { func TestAccountManager_GetAccountByUserID(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -805,7 +808,7 @@ func createAccount(am *DefaultAccountManager, accountID, userID, domain string)
} }
func TestAccountManager_GetAccount(t *testing.T) { func TestAccountManager_GetAccount(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -843,7 +846,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
} }
func TestAccountManager_DeleteAccount(t *testing.T) { func TestAccountManager_DeleteAccount(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -924,7 +927,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
DomainCategory: types.PublicCategory, DomainCategory: types.PublicCategory,
} }
am, err := createManager(b) am, _, err := createManager(b)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
return return
@@ -1016,7 +1019,7 @@ func genUsers(p string, n int) map[string]*types.User {
} }
func TestAccountManager_AddPeer(t *testing.T) { func TestAccountManager_AddPeer(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -1086,7 +1089,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
} }
func TestAccountManager_AddPeerWithUserID(t *testing.T) { func TestAccountManager_AddPeerWithUserID(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -1155,7 +1158,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
} }
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) { func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true") t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SaveGroup(t) testAccountManager_NetworkUpdates_SaveGroup(t)
} }
@@ -1164,7 +1167,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
} }
func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
group := types.Group{ group := types.Group{
ID: "groupA", ID: "groupA",
@@ -1190,8 +1193,8 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
}, true) }, true)
require.NoError(t, err) require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) defer updateManager.CloseChannel(context.Background(), peer1.ID)
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(1) wg.Add(1)
@@ -1215,7 +1218,7 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
} }
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) { func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true") t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePolicy(t) testAccountManager_NetworkUpdates_DeletePolicy(t)
} }
@@ -1224,10 +1227,10 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
} }
func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
manager, account, peer1, _, _ := setupNetworkMapTest(t) manager, updateManager, account, peer1, _, _ := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) defer updateManager.CloseChannel(context.Background(), peer1.ID)
// Ensure that we do not receive an update message before the policy is deleted // Ensure that we do not receive an update message before the policy is deleted
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -1258,7 +1261,7 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
} }
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) { func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true") t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SavePolicy(t) testAccountManager_NetworkUpdates_SavePolicy(t)
} }
@@ -1267,7 +1270,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
} }
func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
manager, account, peer1, peer2, _ := setupNetworkMapTest(t) manager, updateManager, account, peer1, peer2, _ := setupNetworkMapTest(t)
group := types.Group{ group := types.Group{
AccountID: account.Id, AccountID: account.Id,
@@ -1280,8 +1283,8 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
return return
} }
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) defer updateManager.CloseChannel(context.Background(), peer1.ID)
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(1) wg.Add(1)
@@ -1316,7 +1319,7 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
} }
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) { func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true") t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePeer(t) testAccountManager_NetworkUpdates_DeletePeer(t)
} }
@@ -1325,7 +1328,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
} }
func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
manager, account, peer1, _, peer3 := setupNetworkMapTest(t) manager, updateManager, account, peer1, _, peer3 := setupNetworkMapTest(t)
group := types.Group{ group := types.Group{
ID: "groupA", ID: "groupA",
@@ -1354,8 +1357,11 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
return return
} }
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) // We need to sleep to wait for the buffer peer update
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) time.Sleep(300 * time.Millisecond)
updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
defer updateManager.CloseChannel(context.Background(), peer1.ID)
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(1) wg.Add(1)
@@ -1378,7 +1384,7 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
} }
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) { func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true") t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeleteGroup(t) testAccountManager_NetworkUpdates_DeleteGroup(t)
} }
@@ -1387,10 +1393,10 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
} }
func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) defer updateManager.CloseChannel(context.Background(), peer1.ID)
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA", ID: "groupA",
@@ -1457,7 +1463,7 @@ func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
} }
func TestAccountManager_DeletePeer(t *testing.T) { func TestAccountManager_DeletePeer(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -1538,7 +1544,7 @@ func getEvent(t *testing.T, accountID string, manager nbAccount.Manager, eventTy
} }
func TestGetUsersFromAccount(t *testing.T) { func TestGetUsersFromAccount(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -1837,7 +1843,7 @@ func hasNilField(x interface{}) error {
} }
func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -1852,7 +1858,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
} }
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "") _, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -1908,7 +1914,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
} }
func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) { func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -1951,7 +1957,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
} }
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "") _, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -2013,7 +2019,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
} }
func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -2677,7 +2683,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
func TestAccount_SetJWTGroups(t *testing.T) { func TestAccount_SetJWTGroups(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", "postgres") t.Setenv("NETBIRD_STORE_ENGINE", "postgres")
manager, err := createManager(t) manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
// create a new account // create a new account
@@ -2919,18 +2925,18 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
// Fatalf(format string, args ...interface{}) // Fatalf(format string, args ...interface{})
// } // }
func createManager(t testing.TB) (*DefaultAccountManager, error) { func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) {
t.Helper() t.Helper()
store, err := createStore(t) store, err := createStore(t)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
@@ -2948,12 +2954,17 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
manager, err := BuildManager(ctx, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
return manager, nil return manager, updateManager, nil
} }
func createStore(t testing.TB) (store.Store, error) { func createStore(t testing.TB) (store.Store, error) {
@@ -2982,10 +2993,10 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
} }
} }
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
t.Helper() t.Helper()
manager, err := createManager(t) manager, updateManager, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -3026,10 +3037,10 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account,
peer2 := getPeer(manager, setupKey) peer2 := getPeer(manager, setupKey)
peer3 := getPeer(manager, setupKey) peer3 := getPeer(manager, setupKey)
return manager, account, peer1, peer2, peer3 return manager, updateManager, account, peer1, peer2, peer3
} }
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper() t.Helper()
select { select {
case msg := <-updateMessage: case msg := <-updateMessage:
@@ -3039,7 +3050,7 @@ func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessag
} }
} }
func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper() t.Helper()
select { select {
@@ -3077,7 +3088,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
defer log.SetOutput(os.Stderr) defer log.SetOutput(os.Stderr)
for _, bc := range benchCases { for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) { b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil { if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err) b.Fatalf("Failed to setup test account manager: %v", err)
} }
@@ -3086,16 +3097,14 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
if err != nil { if err != nil {
b.Fatalf("Failed to get account: %v", err) b.Fatalf("Failed to get account: %v", err)
} }
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers { for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) updateManager.CreateChannel(ctx, peerID)
} }
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer() b.ResetTimer()
start := time.Now() start := time.Now()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
assert.NoError(b, err) assert.NoError(b, err)
} }
@@ -3140,7 +3149,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
defer log.SetOutput(os.Stderr) defer log.SetOutput(os.Stderr)
for _, bc := range benchCases { for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) { b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil { if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err) b.Fatalf("Failed to setup test account manager: %v", err)
} }
@@ -3149,11 +3158,10 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
if err != nil { if err != nil {
b.Fatalf("Failed to get account: %v", err) b.Fatalf("Failed to get account: %v", err)
} }
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers { for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) updateManager.CreateChannel(ctx, peerID)
} }
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer() b.ResetTimer()
start := time.Now() start := time.Now()
@@ -3210,7 +3218,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
defer log.SetOutput(os.Stderr) defer log.SetOutput(os.Stderr)
for _, bc := range benchCases { for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) { b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil { if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err) b.Fatalf("Failed to setup test account manager: %v", err)
} }
@@ -3219,11 +3227,10 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
if err != nil { if err != nil {
b.Fatalf("Failed to get account: %v", err) b.Fatalf("Failed to get account: %v", err)
} }
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers { for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) updateManager.CreateChannel(ctx, peerID)
} }
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer() b.ResetTimer()
start := time.Now() start := time.Now()
@@ -3282,7 +3289,7 @@ func TestMain(m *testing.M) {
} }
func Test_GetCreateAccountByPrivateDomain(t *testing.T) { func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -3328,7 +3335,7 @@ func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
} }
func Test_UpdateToPrimaryAccount(t *testing.T) { func Test_UpdateToPrimaryAccount(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -3358,7 +3365,7 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
} }
func TestDefaultAccountManager_IsCacheCold(t *testing.T) { func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
require.NoError(t, err) require.NoError(t, err)
t.Run("memory cache", func(t *testing.T) { t.Run("memory cache", func(t *testing.T) {
@@ -3408,7 +3415,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
} }
func TestPropagateUserGroupMemberships(t *testing.T) { func TestPropagateUserGroupMemberships(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
require.NoError(t, err) require.NoError(t, err)
ctx := context.Background() ctx := context.Background()
@@ -3525,7 +3532,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
} }
func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
require.NoError(t, err) require.NoError(t, err)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
@@ -3557,7 +3564,7 @@ func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
} }
func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
require.NoError(t, err) require.NoError(t, err)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
@@ -3596,7 +3603,7 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
} }
func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
@@ -3663,7 +3670,7 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
} }
func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -3709,7 +3716,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
} }
func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) { func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -3,54 +3,23 @@ package server
import ( import (
"context" "context"
"slices" "slices"
"sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/mod/semver"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
dnsForwarderPort = nbdns.ForwarderServerPort dnsForwarderPort = nbdns.ForwarderServerPort
oldForwarderPort = nbdns.ForwarderClientPort
) )
const dnsForwarderPortMinVersion = "v0.59.0"
// DNSConfigCache is a thread-safe cache for DNS configuration components
type DNSConfigCache struct {
NameServerGroups sync.Map
}
// GetNameServerGroup retrieves a cached name server group
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
if c == nil {
return nil, false
}
if value, ok := c.NameServerGroups.Load(key); ok {
return value.(*proto.NameServerGroup), true
}
return nil, false
}
// SetNameServerGroup stores a name server group in the cache
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
if c == nil {
return
}
c.NameServerGroups.Store(key, value)
}
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read) allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
@@ -117,9 +86,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
} }
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
@@ -194,99 +160,3 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
return validateGroups(settings.DisabledManagementGroups, groups) return validateGroups(settings.DisabledManagementGroups, groups)
} }
// computeForwarderPort checks if all peers in the account have updated to a specific version or newer.
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
if len(peers) == 0 {
return int64(oldForwarderPort)
}
reqVer := semver.Canonical(requiredVersion)
// Check if all peers have the required version or newer
for _, peer := range peers {
// Development version is always supported
if peer.Meta.WtVersion == "development" {
continue
}
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
if peerVersion == "" {
// If any peer doesn't have version info, return 0
return int64(oldForwarderPort)
}
// Compare versions
if semver.Compare(peerVersion, reqVer) < 0 {
return int64(oldForwarderPort)
}
}
// All peers have the required version or newer
return int64(dnsForwarderPort)
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache, forwardPort int64) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
ForwarderPort: forwardPort,
}
for _, zone := range update.CustomZones {
protoZone := convertToProtoCustomZone(zone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
for _, nsGroup := range update.NameServerGroups {
cacheKey := nsGroup.ID
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
} else {
protoGroup := convertToProtoNameServerGroup(nsGroup)
cache.SetNameServerGroup(cacheKey, protoGroup)
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
}
return protoUpdate
}
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
return protoZone
}
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
})
}
return protoGroup
}

View File

@@ -2,9 +2,7 @@ package server
import ( import (
"context" "context"
"fmt"
"net/netip" "net/netip"
"reflect"
"testing" "testing"
"time" "time"
@@ -12,6 +10,8 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
@@ -218,7 +218,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
// return empty extra settings for expected calls to UpdateAccountPeers // return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock())
return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createDNSStore(t *testing.T) (store.Store, error) { func createDNSStore(t *testing.T) (store.Store, error) {
@@ -344,247 +350,8 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
return am.Store.GetAccount(context.Background(), account.Id) return am.Store.GetAccount(context.Background(), account.Id)
} }
func generateTestData(size int) nbdns.Config {
config := nbdns.Config{
ServiceEnable: true,
CustomZones: make([]nbdns.CustomZone, size),
NameServerGroups: make([]*nbdns.NameServerGroup, size),
}
for i := 0; i < size; i++ {
config.CustomZones[i] = nbdns.CustomZone{
Domain: fmt.Sprintf("domain%d.com", i),
Records: []nbdns.SimpleRecord{
{
Name: fmt.Sprintf("record%d", i),
Type: 1,
Class: "IN",
TTL: 3600,
RData: "192.168.1.1",
},
},
}
config.NameServerGroups[i] = &nbdns.NameServerGroup{
ID: fmt.Sprintf("group%d", i),
Primary: i == 0,
Domains: []string{fmt.Sprintf("domain%d.com", i)},
SearchDomainsEnabled: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
Port: 53,
NSType: 1,
},
},
}
}
return config
}
func BenchmarkToProtocolDNSConfig(b *testing.B) {
sizes := []int{10, 100, 1000}
for _, size := range sizes {
testData := generateTestData(size)
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
cache := &DNSConfigCache{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
}
})
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &DNSConfigCache{}
toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort))
}
})
}
}
func TestToProtocolDNSConfigWithCache(t *testing.T) {
var cache DNSConfigCache
// Create two different configs
config1 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.com",
Records: []nbdns.SimpleRecord{
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group1",
Name: "Group 1",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
},
},
},
}
config2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.org",
Records: []nbdns.SimpleRecord{
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group2",
Name: "Group 2",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
},
},
},
}
// First run with config1
result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
// Second run with config2
result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort))
// Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort))
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
}
// Verify that result2 is different from result1 and result3
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
t.Errorf("Results should be different for different inputs")
}
if _, exists := cache.GetNameServerGroup("group1"); !exists {
t.Errorf("Cache should contain name server group 'group1'")
}
if _, exists := cache.GetNameServerGroup("group2"); !exists {
t.Errorf("Cache should contain name server group 'group2'")
}
}
func TestComputeForwarderPort(t *testing.T) {
// Test with empty peers list
peers := []*nbpeer.Peer{}
result := computeForwarderPort(peers, "v0.59.0")
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result)
}
// Test with peers that have old versions
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.57.0",
},
},
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.26.0",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result)
}
// Test with peers that have new versions
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.59.0",
},
},
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.59.0",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(dnsForwarderPort) {
t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result)
}
// Test with peers that have mixed versions
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.59.0",
},
},
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "0.57.0",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result)
}
// Test with peers that have empty version
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result)
}
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "development",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result == int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result)
}
// Test with peers that have unknown version string
peers = []*nbpeer.Peer{
{
Meta: nbpeer.PeerSystemMeta{
WtVersion: "unknown",
},
},
}
result = computeForwarderPort(peers, "v0.59.0")
if result != int64(oldForwarderPort) {
t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result)
}
}
func TestDNSAccountPeersUpdate(t *testing.T) { func TestDNSAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{ err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{
{ {
@@ -600,9 +367,9 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() { t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) updateManager.CloseChannel(context.Background(), peer1.ID)
}) })
// Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates // Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates

View File

@@ -28,7 +28,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac
} }
func TestDefaultAccountManager_GetEvents(t *testing.T) { func TestDefaultAccountManager_GetEvents(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
return return
} }

View File

@@ -114,9 +114,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
} }
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
@@ -185,9 +182,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
} }
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
@@ -256,9 +250,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
} }
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
@@ -327,9 +318,6 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
} }
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
@@ -376,7 +364,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err) log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
return nil return nil
} }
dnsDomain := am.GetDNSDomain(settings) dnsDomain := am.networkMapController.GetDNSDomain(settings)
for _, peerID := range addedPeers { for _, peerID := range addedPeers {
peer, ok := peers[peerID] peer, ok := peers[peerID]
@@ -493,9 +481,6 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
} }
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
@@ -534,9 +519,6 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
} }
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
@@ -565,9 +547,6 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
} }
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
@@ -606,9 +585,6 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
} }
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }

View File

@@ -37,7 +37,7 @@ const (
) )
func TestDefaultAccountManager_CreateGroup(t *testing.T) { func TestDefaultAccountManager_CreateGroup(t *testing.T) {
am, err := createManager(t) am, _, err := createManager(t)
if err != nil { if err != nil {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
@@ -74,7 +74,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
} }
func TestDefaultAccountManager_DeleteGroup(t *testing.T) { func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
am, err := createManager(t) am, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatalf("failed to create account manager: %s", err) t.Fatalf("failed to create account manager: %s", err)
} }
@@ -156,7 +156,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
} }
func TestDefaultAccountManager_DeleteGroups(t *testing.T) { func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
am, err := createManager(t) am, _, err := createManager(t)
assert.NoError(t, err, "Failed to create account manager") assert.NoError(t, err, "Failed to create account manager")
manager, account, err := initTestGroupAccount(am) manager, account, err := initTestGroupAccount(am)
@@ -408,7 +408,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
} }
func TestGroupAccountPeersUpdate(t *testing.T) { func TestGroupAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
g := []*types.Group{ g := []*types.Group{
{ {
@@ -442,9 +442,9 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() { t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) updateManager.CloseChannel(context.Background(), peer1.ID)
}) })
// Saving a group that is not linked to any resource should not update account peers // Saving a group that is not linked to any resource should not update account peers
@@ -748,7 +748,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
} }
func Test_AddPeerToGroup(t *testing.T) { func Test_AddPeerToGroup(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -805,7 +805,7 @@ func Test_AddPeerToGroup(t *testing.T) {
} }
func Test_AddPeerToAll(t *testing.T) { func Test_AddPeerToAll(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -862,7 +862,7 @@ func Test_AddPeerToAll(t *testing.T) {
} }
func Test_AddPeerAndAddToAll(t *testing.T) { func Test_AddPeerAndAddToAll(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -942,7 +942,7 @@ func uint32ToIP(n uint32) net.IP {
} }
func Test_IncrementNetworkSerial(t *testing.T) { func Test_IncrementNetworkSerial(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return

View File

@@ -1,39 +0,0 @@
package server
import (
"github.com/netbirdio/netbird/management/server/types"
)
func (am *DefaultAccountManager) enrichAccountFromHolder(account *types.Account) {
a := am.holder.GetAccount(account.Id)
if a == nil {
am.holder.AddAccount(account)
return
}
account.NetworkMapCache = a.NetworkMapCache
if account.NetworkMapCache == nil {
return
}
account.NetworkMapCache.UpdateAccountPointer(account)
am.holder.AddAccount(account)
}
func (am *DefaultAccountManager) getAccountFromHolder(accountID string) *types.Account {
return am.holder.GetAccount(accountID)
}
func (am *DefaultAccountManager) getAccountFromHolderOrInit(accountID string) *types.Account {
a := am.holder.GetAccount(accountID)
if a != nil {
return a
}
account, err := am.holder.LoadOrStoreFunc(accountID, am.requestBuffer.GetAccountWithBackpressure)
if err != nil {
return nil
}
return account
}
func (am *DefaultAccountManager) updateAccountInHolder(account *types.Account) {
am.holder.AddAccount(account)
}

View File

@@ -13,6 +13,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
@@ -65,6 +66,7 @@ func NewAPIHandler(
permissionsManager permissions.Manager, permissionsManager permissions.Manager,
peersManager nbpeers.Manager, peersManager nbpeers.Manager,
settingsManager settings.Manager, settingsManager settings.Manager,
networkMapController network_map.Controller,
) (http.Handler, error) { ) (http.Handler, error) {
var rateLimitingConfig *middleware.RateLimiterConfig var rateLimitingConfig *middleware.RateLimiterConfig
@@ -120,7 +122,7 @@ func NewAPIHandler(
} }
accounts.AddEndpoints(accountManager, settingsManager, router) accounts.AddEndpoints(accountManager, settingsManager, router)
peers.AddEndpoints(accountManager, router) peers.AddEndpoints(accountManager, router, networkMapController)
users.AddEndpoints(accountManager, router) users.AddEndpoints(accountManager, router)
setup_keys.AddEndpoints(accountManager, router) setup_keys.AddEndpoints(accountManager, router)
policies.AddEndpoints(accountManager, LocationManager, router) policies.AddEndpoints(accountManager, LocationManager, router)

View File

@@ -10,6 +10,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
@@ -23,11 +24,12 @@ import (
// Handler is a handler that returns peers of the account // Handler is a handler that returns peers of the account
type Handler struct { type Handler struct {
accountManager account.Manager accountManager account.Manager
networkMapController network_map.Controller
} }
func AddEndpoints(accountManager account.Manager, router *mux.Router) { func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) {
peersHandler := NewHandler(accountManager) peersHandler := NewHandler(accountManager, networkMapController)
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS") Methods("GET", "PUT", "DELETE", "OPTIONS")
@@ -36,9 +38,10 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
} }
// NewHandler creates a new peers Handler // NewHandler creates a new peers Handler
func NewHandler(accountManager account.Manager) *Handler { func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler {
return &Handler{ return &Handler{
accountManager: accountManager, accountManager: accountManager,
networkMapController: networkMapController,
} }
} }
@@ -47,7 +50,7 @@ func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) {
if peer.Status.Connected { if peer.Status.Connected {
// Although we have online status in store we do not yet have an updated channel so have to show it as disconnected // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected
// This may happen after server restart when not all peers are yet connected // This may happen after server restart when not all peers are yet connected
if !h.accountManager.HasConnectedChannel(peer.ID) { if !h.networkMapController.IsConnected(peer.ID) {
peerToReturn.Status.Connected = false peerToReturn.Status.Connected = false
} }
} }
@@ -73,7 +76,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
return return
} }
dnsDomain := h.accountManager.GetDNSDomain(settings) dnsDomain := h.networkMapController.GetDNSDomain(settings)
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
@@ -139,7 +142,7 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
util.WriteError(ctx, err, w) util.WriteError(ctx, err, w)
return return
} }
dnsDomain := h.accountManager.GetDNSDomain(settings) dnsDomain := h.networkMapController.GetDNSDomain(settings)
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
if err != nil { if err != nil {
@@ -227,7 +230,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
dnsDomain := h.accountManager.GetDNSDomain(settings) dnsDomain := h.networkMapController.GetDNSDomain(settings)
grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
@@ -317,7 +320,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
return return
} }
dnsDomain := h.accountManager.GetDNSDomain(account.Settings) dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)

View File

@@ -14,12 +14,14 @@ import (
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"go.uber.org/mock/gomock"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -36,7 +38,7 @@ const (
serviceUser = "service_user" serviceUser = "service_user"
) )
func initTestMetaData(peers ...*nbpeer.Peer) *Handler { func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
peersMap := make(map[string]*nbpeer.Peer) peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers { for _, peer := range peers {
@@ -99,6 +101,22 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
}, },
} }
ctrl := gomock.NewController(t)
networkMapController := network_map.NewMockController(ctrl)
networkMapController.EXPECT().
GetDNSDomain(gomock.Any()).
Return("domain").
AnyTimes()
networkMapController.EXPECT().
IsConnected(noUpdateChannelTestPeerID).
Return(false).
AnyTimes()
networkMapController.EXPECT().
IsConnected(gomock.Any()).
Return(true).
AnyTimes()
return &Handler{ return &Handler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
@@ -187,6 +205,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
return account.Settings, nil return account.Settings, nil
}, },
}, },
networkMapController: networkMapController,
} }
} }
@@ -270,7 +289,7 @@ func TestGetPeers(t *testing.T) {
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
p := initTestMetaData(peer, peer1) p := initTestMetaData(t, peer, peer1)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@@ -374,7 +393,7 @@ func TestGetAccessiblePeers(t *testing.T) {
UserID: regularUser, UserID: regularUser,
} }
p := initTestMetaData(peer1, peer2, peer3) p := initTestMetaData(t, peer1, peer2, peer3)
tt := []struct { tt := []struct {
name string name string
@@ -477,7 +496,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
}, },
} }
p := initTestMetaData(testPeer) p := initTestMetaData(t, testPeer)
tt := []struct { tt := []struct {
name string name string

View File

@@ -10,6 +10,10 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
@@ -31,7 +35,7 @@ import (
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
) )
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *network_map.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir())
if err != nil { if err != nil {
t.Fatalf("Failed to create test store: %v", err) t.Fatalf("Failed to create test store: %v", err)
@@ -43,7 +47,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
t.Fatalf("Failed to create metrics: %v", err) t.Fatalf("Failed to create metrics: %v", err)
} }
peersUpdateManager := server.NewPeersUpdateManager(nil) peersUpdateManager := update_channel.NewPeersUpdateManager(nil)
updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId) updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId)
done := make(chan struct{}) done := make(chan struct{})
if validateUpdate { if validateUpdate {
@@ -63,7 +67,11 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
userManager := users.NewManager(store) userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager) settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
ctx := context.Background()
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock())
am, err := server.BuildManager(ctx, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
if err != nil { if err != nil {
t.Fatalf("Failed to create manager: %v", err) t.Fatalf("Failed to create manager: %v", err)
} }
@@ -83,7 +91,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
groupsManagerMock := groups.NewManagerMock() groupsManagerMock := groups.NewManagerMock()
peersManager := peers.NewManager(store, permissionsManager) peersManager := peers.NewManager(store, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager) apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController)
if err != nil { if err != nil {
t.Fatalf("Failed to create API handler: %v", err) t.Fatalf("Failed to create API handler: %v", err)
} }
@@ -91,7 +99,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
return apiHandler, am, done return apiHandler, am, done
} }
func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) { func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper() t.Helper()
select { select {
case msg := <-updateMessage: case msg := <-updateMessage:
@@ -101,7 +109,7 @@ func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server
} }
} }
func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) { func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage, expected *network_map.UpdateMessage) {
t.Helper() t.Helper()
select { select {

View File

@@ -22,10 +22,14 @@ import (
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
@@ -321,99 +325,6 @@ func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServ
return loginResp, nil return loginResp, nil
} }
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
testingServerKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
}
testingClientKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
}
testCases := []struct {
name string
inputFlow *config.DeviceAuthorizationFlow
expectedFlow *mgmtProto.DeviceAuthorizationFlow
expectedErrFunc require.ErrorAssertionFunc
expectedErrMSG string
expectedComparisonFunc require.ComparisonAssertionFunc
expectedComparisonMSG string
}{
{
name: "Testing No Device Flow Config",
inputFlow: nil,
expectedErrFunc: require.Error,
expectedErrMSG: "should return error",
},
{
name: "Testing Invalid Device Flow Provider Config",
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "NoNe",
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
expectedErrFunc: require.Error,
expectedErrMSG: "should return error",
},
{
name: "Testing Full Device Flow Config",
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "hosted",
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
Provider: 0,
ProviderConfig: &mgmtProto.ProviderConfig{
ClientID: "test",
},
},
expectedErrFunc: require.NoError,
expectedErrMSG: "should not return error",
expectedComparisonFunc: require.Equal,
expectedComparisonMSG: "should match",
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &GRPCServer{
wgKey: testingServerKey,
config: &config.Config{
DeviceAuthorizationFlow: testCase.inputFlow,
},
}
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message)
require.NoError(t, err, "should be able to encrypt message")
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
context.TODO(),
&mgmtProto.EncryptedMessage{
WgPubKey: testingClientKey.PublicKey().String(),
Body: encryptedMSG,
},
)
testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
if testCase.expectedComparisonFunc != nil {
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
require.NoError(t, err, "should be able to decrypt")
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
}
})
}
}
func startManagementForTest(t *testing.T, testFile string, config *config.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) { func startManagementForTest(t *testing.T, testFile string, config *config.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", "localhost:0") lis, err := net.Listen("tcp", "localhost:0")
@@ -427,7 +338,6 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
t.Fatal(err) t.Fatal(err)
} }
peersUpdateManager := NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
ctx := context.WithValue(context.Background(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck ctx := context.WithValue(context.Background(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
@@ -451,7 +361,10 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := BuildManager(ctx, store, networkMapController, nil, "",
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
@@ -459,10 +372,10 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
return nil, nil, "", cleanup, err return nil, nil, "", cleanup, err
} }
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
ephemeralMgr := manager.NewEphemeralManager(store, accountManager) ephemeralMgr := manager.NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}) mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}, networkMapController)
if err != nil { if err != nil {
return nil, nil, "", cleanup, err return nil, nil, "", cleanup, err
} }
@@ -764,9 +677,38 @@ func Test_LoginPerformance(t *testing.T) {
peerLogin := types.PeerLogin{ peerLogin := types.PeerLogin{
WireGuardPubKey: key.String(), WireGuardPubKey: key.String(),
SSHKey: "random", SSHKey: "random",
Meta: extractPeerMeta(context.Background(), meta), Meta: nbpeer.PeerSystemMeta{
SetupKey: setupKey.Key, Hostname: meta.GetHostname(),
ConnectionIP: net.IP{1, 1, 1, 1}, GoOS: meta.GetGoOS(),
Kernel: meta.GetKernel(),
Platform: meta.GetPlatform(),
OS: meta.GetOS(),
OSVersion: meta.GetOSVersion(),
WtVersion: meta.GetNetbirdVersion(),
UIVersion: meta.GetUiVersion(),
KernelVersion: meta.GetKernelVersion(),
SystemSerialNumber: meta.GetSysSerialNumber(),
SystemProductName: meta.GetSysProductName(),
SystemManufacturer: meta.GetSysManufacturer(),
Environment: nbpeer.Environment{
Cloud: meta.GetEnvironment().GetCloud(),
Platform: meta.GetEnvironment().GetPlatform(),
},
Flags: nbpeer.Flags{
RosenpassEnabled: meta.GetFlags().GetRosenpassEnabled(),
RosenpassPermissive: meta.GetFlags().GetRosenpassPermissive(),
ServerSSHAllowed: meta.GetFlags().GetServerSSHAllowed(),
DisableClientRoutes: meta.GetFlags().GetDisableClientRoutes(),
DisableServerRoutes: meta.GetFlags().GetDisableServerRoutes(),
DisableDNS: meta.GetFlags().GetDisableDNS(),
DisableFirewall: meta.GetFlags().GetDisableFirewall(),
BlockLANAccess: meta.GetFlags().GetBlockLANAccess(),
BlockInbound: meta.GetFlags().GetBlockInbound(),
LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(),
},
},
SetupKey: setupKey.Key,
ConnectionIP: net.IP{1, 1, 1, 1},
} }
login := func() error { login := func() error {

View File

@@ -20,7 +20,10 @@ import (
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
@@ -176,7 +179,6 @@ func startServer(
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
} }
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
@@ -199,13 +201,18 @@ func startServer(
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(str) permissionsManager := permissions.NewManager(str)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(ctx, str)
networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := server.BuildManager( accountManager, err := server.BuildManager(
context.Background(), context.Background(),
str, str,
peersUpdateManager, networkMapController,
nil, nil,
"", "",
"netbird.selfhosted",
eventStore, eventStore,
nil, nil,
false, false,
@@ -220,18 +227,18 @@ func startServer(
} }
groupsManager := groups.NewManager(str, permissionsManager, accountManager) groupsManager := groups.NewManager(str, permissionsManager, accountManager)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer( mgmtServer, err := nbgrpc.NewServer(
context.Background(),
config, config,
accountManager, accountManager,
settingsMockManager, settingsMockManager,
peersUpdateManager, updateManager,
secretsManager, secretsManager,
nil, nil,
&manager.EphemeralManager{}, &manager.EphemeralManager{},
nil, nil,
server.MockIntegratedValidator{}, server.MockIntegratedValidator{},
networkMapController,
) )
if err != nil { if err != nil {
t.Fatalf("failed creating management server: %v", err) t.Fatalf("failed creating management server: %v", err)

View File

@@ -38,7 +38,7 @@ type MockAccountManager struct {
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
@@ -94,7 +94,7 @@ type MockAccountManager struct {
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error
@@ -178,11 +178,11 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
} }
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if am.SyncAndMarkPeerFunc != nil { if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
} }
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
} }
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error { func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
@@ -747,11 +747,11 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLog
} }
// SyncPeer mocks SyncPeer of the AccountManager interface // SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if am.SyncPeerFunc != nil { if am.SyncPeerFunc != nil {
return am.SyncPeerFunc(ctx, sync, accountID) return am.SyncPeerFunc(ctx, sync, accountID)
} }
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
} }
// GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface // GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface

View File

@@ -83,9 +83,6 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
@@ -137,9 +134,6 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
@@ -183,9 +177,6 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }

View File

@@ -11,6 +11,8 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -785,7 +787,13 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createNSStore(t *testing.T) (store.Store, error) { func createNSStore(t *testing.T) (store.Store, error) {
@@ -975,7 +983,7 @@ func TestValidateDomain(t *testing.T) {
} }
func TestNameServerAccountPeersUpdate(t *testing.T) { func TestNameServerAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
var newNameServerGroupA *nbdns.NameServerGroup var newNameServerGroupA *nbdns.NameServerGroup
var newNameServerGroupB *nbdns.NameServerGroup var newNameServerGroupB *nbdns.NameServerGroup
@@ -994,9 +1002,9 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() { t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) updateManager.CloseChannel(context.Background(), peer1.ID)
}) })
// Creating a nameserver group with a distribution group no peers should not update account peers // Creating a nameserver group with a distribution group no peers should not update account peers

View File

@@ -1,80 +0,0 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
)
func (am *DefaultAccountManager) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
am.enrichAccountFromHolder(account)
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
}
func (am *DefaultAccountManager) getPeerNetworkMapExp(
ctx context.Context,
accountId string,
peerId string,
validatedPeers map[string]struct{},
customZone nbdns.CustomZone,
metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
account := am.getAccountFromHolderOrInit(accountId)
if account == nil {
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
return &types.NetworkMap{
Network: &types.Network{},
}
}
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
}
func (am *DefaultAccountManager) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error {
am.enrichAccountFromHolder(account)
return account.OnPeerAddedUpdNetworkMapCache(peerId)
}
func (am *DefaultAccountManager) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
am.enrichAccountFromHolder(account)
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
}
func (am *DefaultAccountManager) updatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
account := am.getAccountFromHolder(accountId)
if account == nil {
return
}
account.UpdatePeerInNetworkMapCache(peer)
}
func (am *DefaultAccountManager) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
account.RecalculateNetworkMapCache(validatedPeers)
am.updateAccountInHolder(account)
}
func (am *DefaultAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
if am.experimentalNetworkMap(accountId) {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
return err
}
validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
return err
}
am.recalculateNetworkMapCache(account, validatedPeers)
}
return nil
}
func (am *DefaultAccountManager) experimentalNetworkMap(accountId string) bool {
_, ok := am.expNewNetworkMapAIDs[accountId]
return am.expNewNetworkMap || ok
}

View File

@@ -177,9 +177,6 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
event() event()
} }
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
go m.accountManager.UpdateAccountPeers(ctx, accountID) go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil return nil

View File

@@ -157,9 +157,6 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
event() event()
} }
if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil {
return nil, err
}
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
return resource, nil return resource, nil
@@ -260,9 +257,6 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
event() event()
} }
if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil {
return nil, err
}
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
return resource, nil return resource, nil
@@ -337,9 +331,6 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
event() event()
} }
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
go m.accountManager.UpdateAccountPeers(ctx, accountID) go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil return nil

View File

@@ -119,9 +119,6 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network)) m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network))
if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil {
return nil, err
}
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
return router, nil return router, nil
@@ -186,9 +183,6 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network)) m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network))
if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil {
return nil, err
}
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
return router, nil return router, nil
@@ -223,9 +217,6 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
event() event()
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
go m.accountManager.UpdateAccountPeers(ctx, accountID) go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil return nil

View File

@@ -8,8 +8,6 @@ import (
"net" "net"
"slices" "slices"
"strings" "strings"
"sync"
"sync/atomic"
"time" "time"
"github.com/rs/xid" "github.com/rs/xid"
@@ -23,7 +21,6 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -31,7 +28,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
@@ -140,12 +136,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
} }
if expired { if expired {
if am.experimentalNetworkMap(accountID) { am.networkMapController.OnPeerUpdated(accountID, peer)
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
am.BufferUpdateAccountPeers(ctx, accountID)
} }
return nil return nil
@@ -201,7 +192,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
var peer *nbpeer.Peer var peer *nbpeer.Peer
var settings *types.Settings var settings *types.Settings
var peerGroupList []string var peerGroupList []string
var requiresPeerUpdates bool
var peerLabelChanged bool var peerLabelChanged bool
var sshChanged bool var sshChanged bool
var loginExpirationChanged bool var loginExpirationChanged bool
@@ -224,9 +214,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return err return err
} }
dnsDomain = am.GetDNSDomain(settings) dnsDomain = am.networkMapController.GetDNSDomain(settings)
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra) update, _, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra)
if err != nil { if err != nil {
return err return err
} }
@@ -319,15 +309,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
} }
} }
if am.experimentalNetworkMap(accountID) { am.networkMapController.OnPeerUpdated(accountID, peer)
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
if peerLabelChanged || requiresPeerUpdates {
am.UpdateAccountPeers(ctx, accountID)
} else if sshChanged {
am.UpdateAccountPeer(ctx, accountID, peer.ID)
}
return peer, nil return peer, nil
} }
@@ -383,20 +365,13 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
storeEvent() storeEvent()
} }
if am.experimentalNetworkMap(accountID) { err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID)
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil {
if err != nil { log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err)
return err
}
if err := am.onPeerDeletedUpdNetworkMapCache(account, peerID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err)
}
} }
if userID != activity.SystemInitiator { if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peerID); err != nil {
am.BufferUpdateAccountPeers(ctx, accountID) log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err)
} }
return nil return nil
@@ -404,47 +379,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
account, err := am.Store.GetAccountByPeerID(ctx, peerID) return am.networkMapController.GetNetworkMap(ctx, peerID)
if err != nil {
return nil, err
}
peer := account.GetPeer(peerID)
if peer == nil {
return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
}
groups := make(map[string][]string)
for groupID, group := range account.Groups {
groups[groupID] = group.Peers
}
validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, err
}
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, err
}
var networkMap *types.NetworkMap
if am.experimentalNetworkMap(peer.AccountID) {
networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
networkMap.Merge(proxyNetworkMap)
}
return networkMap, nil
} }
// GetPeerNetwork returns the Network for a given peer // GetPeerNetwork returns the Network for a given peer
@@ -703,27 +638,19 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
} }
opEvent.TargetID = newPeer.ID opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings)) opEvent.Meta = newPeer.EventMeta(am.networkMapController.GetDNSDomain(settings))
if !addedByUser { if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName opEvent.Meta["setup_key_name"] = setupKeyName
} }
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
if am.experimentalNetworkMap(accountID) { if err := am.networkMapController.OnPeerAdded(ctx, accountID, newPeer.ID); err != nil {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
if err != nil {
return nil, nil, nil, err
}
if err := am.onPeerAddedUpdNetworkMapCache(account, newPeer.ID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
}
} }
am.BufferUpdateAccountPeers(ctx, accountID) p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer)
return p, nmap, pc, err
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
} }
func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) {
@@ -738,7 +665,7 @@ func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) {
} }
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
var peer *nbpeer.Peer var peer *nbpeer.Peer
var peerNotValid bool var peerNotValid bool
var isStatusChanged bool var isStatusChanged bool
@@ -748,7 +675,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, 0, err
} }
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -798,17 +725,14 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil return nil
}) })
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, 0, err
} }
if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) {
if am.experimentalNetworkMap(accountID) { am.networkMapController.OnPeerUpdated(accountID, peer)
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
am.BufferUpdateAccountPeers(ctx, accountID)
} }
return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
} }
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
@@ -933,15 +857,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction)) log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction))
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
if am.experimentalNetworkMap(accountID) { am.networkMapController.OnPeerUpdated(accountID, peer)
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
startBuffer := time.Now()
am.BufferUpdateAccountPeers(ctx, accountID)
log.WithContext(ctx).Debugf("LoginPeer: BufferUpdateAccountPeers took %v", time.Since(startBuffer))
} }
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
return p, nmap, pc, err
} }
// getPeerPostureChecks returns the posture checks for the peer. // getPeerPostureChecks returns the posture checks for the peer.
@@ -1033,68 +953,6 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil return nil
} }
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
}
emptyMap := &types.NetworkMap{
Network: network.Copy(),
}
return peer, emptyMap, nil, nil
}
var (
account *types.Account
err error
)
if am.experimentalNetworkMap(accountID) {
account = am.getAccountFromHolderOrInit(accountID)
} else {
account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
}
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, err
}
startPosture := time.Now()
postureChecks, err := am.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, err
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, nil, nil, err
}
var networkMap *types.NetworkMap
if am.experimentalNetworkMap(accountID) {
networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics())
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
networkMap.Merge(proxyNetworkMap)
}
return peer, networkMap, postureChecks, nil
}
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction store.Store, user *types.User, peer *nbpeer.Peer) error { func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction store.Store, user *types.User, peer *nbpeer.Peer) error {
err := checkAuth(ctx, user.Id, peer) err := checkAuth(ctx, user.Id, peer)
if err != nil { if err != nil {
@@ -1118,7 +976,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
return fmt.Errorf("failed to get account settings: %w", err) return fmt.Errorf("failed to get account settings: %w", err)
} }
am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain(settings))) am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.networkMapController.GetDNSDomain(settings)))
return nil return nil
} }
@@ -1214,232 +1072,17 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
// UpdateAccountPeers updates all peers that belong to an account. // UpdateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers. // Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) _ = am.networkMapController.UpdateAccountPeers(ctx, accountID)
var (
account *types.Account
err error
)
if am.experimentalNetworkMap(accountID) {
account = am.getAccountFromHolderOrInit(accountID)
} else {
account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
return
}
}
globalStart := time.Now()
hasPeersConnected := false
for _, peer := range account.Peers {
if am.peersUpdateManager.HasChannel(peer.ID) {
hasPeersConnected = true
break
}
}
if !hasPeersConnected {
return
}
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err)
return
}
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
dnsCache := &DNSConfigCache{}
dnsDomain := am.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
if am.experimentalNetworkMap(accountID) {
am.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
}
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return
}
extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err)
return
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
for _, peer := range account.Peers {
if !am.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
continue
}
wg.Add(1)
semaphore <- struct{}{}
go func(p *nbpeer.Peer) {
defer wg.Done()
defer func() { <-semaphore }()
start := time.Now()
postureChecks, err := am.getPeerPostureChecks(account, p.ID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err)
return
}
am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
var remotePeerNetworkMap *types.NetworkMap
if am.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics())
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
}
am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
start = time.Now()
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now()
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
}(peer)
}
//
wg.Wait()
if am.metrics != nil {
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
}
}
type bufferUpdate struct {
mu sync.Mutex
next *time.Timer
update atomic.Bool
} }
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName()) _ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID)
bufUpd, _ := am.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
am.UpdateAccountPeers(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(am.updateAccountPeersBufferInterval.Load()), func() {
am.UpdateAccountPeers(ctx, accountID)
})
return
}
b.next.Reset(time.Duration(am.updateAccountPeersBufferInterval.Load()))
}()
} }
// UpdateAccountPeer updates a single peer that belongs to an account. // UpdateAccountPeer updates a single peer that belongs to an account.
// Should be called when changes need to be synced to a specific peer only. // Should be called when changes need to be synced to a specific peer only.
func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) { func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) {
if !am.peersUpdateManager.HasChannel(peerId) { _ = am.networkMapController.UpdateAccountPeer(ctx, accountId, peerId)
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId)
return
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peer %s. failed to get account: %v", peerId, err)
return
}
peer := account.GetPeer(peerId)
if peer == nil {
log.WithContext(ctx).Tracef("peer %s doesn't exists in account %s", peerId, accountId)
return
}
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err)
return
}
dnsCache := &DNSConfigCache{}
dnsDomain := am.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
postureChecks, err := am.getPeerPostureChecks(account, peerId)
if err != nil {
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err)
return
}
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId, peerId, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return
}
var remotePeerNetworkMap *types.NetworkMap
if am.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics())
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, peer.AccountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
return
}
peerGroups := account.GetPeerGroups(peerId)
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
} }
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. // getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
@@ -1594,14 +1237,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
if err != nil { if err != nil {
return nil, err return nil, err
} }
dnsDomain := am.GetDNSDomain(settings) dnsDomain := am.networkMapController.GetDNSDomain(settings)
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
dnsFwdPort := computeForwarderPort(peers, dnsForwarderPortMinVersion)
for _, peer := range peers { for _, peer := range peers {
if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil { if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil {
@@ -1635,24 +1271,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil { if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil {
return nil, err return nil, err
} }
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{
Update: &proto.SyncResponse{
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
NetworkMap: &proto.NetworkMap{
Serial: network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
DNSConfig: &proto.DNSConfig{
ForwarderPort: dnsFwdPort,
},
},
},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
peerDeletedEvents = append(peerDeletedEvents, func() { peerDeletedEvents = append(peerDeletedEvents, func() {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
}) })
@@ -1661,14 +1279,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
return peerDeletedEvents, nil return peerDeletedEvents, nil
} }
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
labelMap := make(map[string]struct{}, len(existingLabels))
for _, label := range existingLabels {
labelMap[label] = struct{}{}
}
return labelMap
}
// validatePeerDelete checks if the peer can be deleted. // validatePeerDelete checks if the peer can be deleted.
func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transaction store.Store, accountId, peerId string) error { func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transaction store.Store, accountId, peerId string) error {
linkedInIngressPorts, err := am.proxyController.IsPeerInIngressPorts(ctx, accountId, peerId) linkedInIngressPorts, err := am.proxyController.IsPeerInIngressPorts(ctx, accountId, peerId)

View File

@@ -13,7 +13,6 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
@@ -25,10 +24,14 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
@@ -172,12 +175,12 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
} }
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) { func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true") t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testGetNetworkMapGeneral(t) testGetNetworkMapGeneral(t)
} }
func testGetNetworkMapGeneral(t *testing.T) { func testGetNetworkMapGeneral(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -249,7 +252,7 @@ func testGetNetworkMapGeneral(t *testing.T) {
func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
// TODO: disable until we start use policy again // TODO: disable until we start use policy again
t.Skip() t.Skip()
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -426,7 +429,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
} }
func TestAccountManager_GetPeerNetwork(t *testing.T) { func TestAccountManager_GetPeerNetwork(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -487,7 +490,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
} }
func TestDefaultAccountManager_GetPeer(t *testing.T) { func TestDefaultAccountManager_GetPeer(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -674,7 +677,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
} }
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -742,12 +745,12 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
} }
} }
func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, string, string, error) { func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, *update_channel.PeersUpdateManager, string, string, error) {
b.Helper() b.Helper()
manager, err := createManager(b) manager, updateManager, err := createManager(b)
if err != nil { if err != nil {
return nil, "", "", err return nil, nil, "", "", err
} }
accountID := "test_account" accountID := "test_account"
@@ -798,7 +801,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
ips := account.GetTakenIPs() ips := account.GetTakenIPs()
peerIP, err := types.AllocatePeerIP(account.Network.Net, ips) peerIP, err := types.AllocatePeerIP(account.Network.Net, ips)
if err != nil { if err != nil {
return nil, "", "", err return nil, nil, "", "", err
} }
peerKey, _ := wgtypes.GeneratePrivateKey() peerKey, _ := wgtypes.GeneratePrivateKey()
@@ -904,10 +907,10 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
err = manager.Store.SaveAccount(context.Background(), account) err = manager.Store.SaveAccount(context.Background(), account)
if err != nil { if err != nil {
return nil, "", "", err return nil, nil, "", "", err
} }
return manager, accountID, regularUser, nil return manager, updateManager, accountID, regularUser, nil
} }
func BenchmarkGetPeers(b *testing.B) { func BenchmarkGetPeers(b *testing.B) {
@@ -928,7 +931,7 @@ func BenchmarkGetPeers(b *testing.B) {
defer log.SetOutput(os.Stderr) defer log.SetOutput(os.Stderr)
for _, bc := range benchCases { for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) { b.Run(bc.name, func(b *testing.B) {
manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups) manager, _, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil { if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err) b.Fatalf("Failed to setup test account manager: %v", err)
} }
@@ -968,7 +971,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
for _, bc := range benchCases { for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) { b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil { if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err) b.Fatalf("Failed to setup test account manager: %v", err)
} }
@@ -980,14 +983,10 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
b.Fatalf("Failed to get account: %v", err) b.Fatalf("Failed to get account: %v", err)
} }
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers { for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) updateManager.CreateChannel(ctx, peerID)
} }
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer() b.ResetTimer()
start := time.Now() start := time.Now()
@@ -1013,7 +1012,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
} }
func TestUpdateAccountPeers_Experimental(t *testing.T) { func TestUpdateAccountPeers_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true") t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testUpdateAccountPeers(t) testUpdateAccountPeers(t)
} }
@@ -1037,7 +1036,7 @@ func testUpdateAccountPeers(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
manager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups) manager, updateManager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups)
if err != nil { if err != nil {
t.Fatalf("Failed to setup test account manager: %v", err) t.Fatalf("Failed to setup test account manager: %v", err)
} }
@@ -1049,13 +1048,12 @@ func testUpdateAccountPeers(t *testing.T) {
t.Fatalf("Failed to get account: %v", err) t.Fatalf("Failed to get account: %v", err)
} }
peerChannels := make(map[string]chan *UpdateMessage) peerChannels := make(map[string]chan *network_map.UpdateMessage)
for peerID := range account.Peers { for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) peerChannels[peerID] = updateManager.CreateChannel(ctx, peerID)
} }
manager.peersUpdateManager.peerChannels = peerChannels
manager.UpdateAccountPeers(ctx, account.Id) manager.UpdateAccountPeers(ctx, account.Id)
for _, channel := range peerChannels { for _, channel := range peerChannels {
@@ -1097,7 +1095,7 @@ func TestToSyncResponse(t *testing.T) {
DNSLabel: "peer1", DNSLabel: "peer1",
SSHKey: "peer1-ssh-key", SSHKey: "peer1-ssh-key",
} }
turnRelayToken := &Token{ turnRelayToken := &grpc.Token{
Payload: "turn-user", Payload: "turn-user",
Signature: "turn-pass", Signature: "turn-pass",
} }
@@ -1177,9 +1175,9 @@ func TestToSyncResponse(t *testing.T) {
}, },
}, },
} }
dnsCache := &DNSConfigCache{} dnsCache := &cache.DNSConfigCache{}
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) response := grpc.ToSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort))
assert.NotNil(t, response) assert.NotNil(t, response)
// assert peer config // assert peer config
@@ -1289,7 +1287,12 @@ func Test_RegisterPeerByUser(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1369,7 +1372,12 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1517,7 +1525,12 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1566,7 +1579,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
} }
func Test_LoginPeer(t *testing.T) { func Test_LoginPeer(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true") t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet") t.Skip("The SQLite store is not properly supported by Windows yet")
} }
@@ -1592,7 +1605,12 @@ func Test_LoginPeer(t *testing.T) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1725,7 +1743,7 @@ func Test_LoginPeer(t *testing.T) {
} }
func TestPeerAccountPeersUpdate(t *testing.T) { func TestPeerAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID)
require.NoError(t, err) require.NoError(t, err)
@@ -1782,13 +1800,14 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
var peer5 *nbpeer.Peer var peer5 *nbpeer.Peer
var peer6 *nbpeer.Peer var peer6 *nbpeer.Peer
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() { t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) updateManager.CloseChannel(context.Background(), peer1.ID)
}) })
// Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update // Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update
t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) { t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) {
t.Skip("Currently all updates will trigger a network map")
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
peerShouldNotReceiveUpdate(t, updMsg) peerShouldNotReceiveUpdate(t, updMsg)
@@ -1890,6 +1909,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
}) })
t.Run("validator requires no update", func(t *testing.T) { t.Run("validator requires no update", func(t *testing.T) {
t.Skip("Currently all updates will trigger a network map")
requireNoUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) { requireNoUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, false, nil return update, false, nil
} }
@@ -2091,7 +2112,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
} }
func Test_DeletePeer(t *testing.T) { func Test_DeletePeer(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -2188,7 +2209,7 @@ func Test_IsUniqueConstraintError(t *testing.T) {
} }
func Test_AddPeer(t *testing.T) { func Test_AddPeer(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -2276,136 +2297,8 @@ func Test_AddPeer(t *testing.T) {
assert.Equal(t, uint64(totalPeers), account.Network.Serial) assert.Equal(t, uint64(totalPeers), account.Network.Serial)
} }
func TestBufferUpdateAccountPeers(t *testing.T) {
const (
peersCount = 1000
updateAccountInterval = 50 * time.Millisecond
)
var (
deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32
uapLastRun, dpLastRun atomic.Int64
totalNewRuns, totalOldRuns int
)
uap := func(ctx context.Context, accountID string) {
updatePeersDeleted.Store(deletedPeers.Load())
updatePeersRuns.Add(1)
uapLastRun.Store(time.Now().UnixMilli())
time.Sleep(100 * time.Millisecond)
}
t.Run("new approach", func(t *testing.T) {
updatePeersRuns.Store(0)
updatePeersDeleted.Store(0)
deletedPeers.Store(0)
var mustore sync.Map
bufupd := func(ctx context.Context, accountID string) {
mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{})
b := mu.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
uap(ctx, accountID)
if !b.update.Load() {
return
}
b.update.Store(false)
b.next = time.AfterFunc(updateAccountInterval, func() {
uap(ctx, accountID)
})
}()
}
dp := func(ctx context.Context, accountID, peerID, userID string) error {
deletedPeers.Add(1)
dpLastRun.Store(time.Now().UnixMilli())
time.Sleep(10 * time.Millisecond)
bufupd(ctx, accountID)
return nil
}
am := mock_server.MockAccountManager{
UpdateAccountPeersFunc: uap,
BufferUpdateAccountPeersFunc: bufupd,
DeletePeerFunc: dp,
}
empty := ""
for range peersCount {
//nolint
am.DeletePeer(context.Background(), empty, empty, empty)
}
time.Sleep(100 * time.Millisecond)
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
totalNewRuns = int(updatePeersRuns.Load())
})
t.Run("old approach", func(t *testing.T) {
updatePeersRuns.Store(0)
updatePeersDeleted.Store(0)
deletedPeers.Store(0)
var mustore sync.Map
bufupd := func(ctx context.Context, accountID string) {
mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{})
b := mu.(*sync.Mutex)
if !b.TryLock() {
return
}
go func() {
time.Sleep(updateAccountInterval)
b.Unlock()
uap(ctx, accountID)
}()
}
dp := func(ctx context.Context, accountID, peerID, userID string) error {
deletedPeers.Add(1)
dpLastRun.Store(time.Now().UnixMilli())
time.Sleep(10 * time.Millisecond)
bufupd(ctx, accountID)
return nil
}
am := mock_server.MockAccountManager{
UpdateAccountPeersFunc: uap,
BufferUpdateAccountPeersFunc: bufupd,
DeletePeerFunc: dp,
}
empty := ""
for range peersCount {
//nolint
am.DeletePeer(context.Background(), empty, empty, empty)
}
time.Sleep(100 * time.Millisecond)
assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted")
assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer")
assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer")
totalOldRuns = int(updatePeersRuns.Load())
})
assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns)
}
func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -2442,7 +2335,7 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
} }
func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -2476,7 +2369,7 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
} }
func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -2541,7 +2434,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
} }
func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) { func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -10,7 +10,6 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
@@ -77,9 +76,6 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
@@ -123,9 +119,6 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
@@ -258,31 +251,3 @@ func getValidGroupIDs(groups map[string]*types.Group, groupIDs []string) []strin
return validIDs return validIDs
} }
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules))
for i := range rules {
rule := rules[i]
fwRule := &proto.FirewallRule{
PolicyID: []byte(rule.PolicyID),
PeerIP: rule.PeerIP,
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
}
if shouldUsePortRange(fwRule) {
fwRule.PortInfo = rule.PortRange.ToProto()
}
result[i] = fwRule
}
return result
}
func shouldUsePortRange(rule *proto.FirewallRule) bool {
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
}

View File

@@ -1135,7 +1135,7 @@ func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int {
} }
func TestPolicyAccountPeersUpdate(t *testing.T) { func TestPolicyAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
g := []*types.Group{ g := []*types.Group{
{ {
@@ -1164,9 +1164,9 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() { t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) updateManager.CloseChannel(context.Background(), peer1.ID)
}) })
var policyWithGroupRulesNoPeers *types.Policy var policyWithGroupRulesNoPeers *types.Policy

View File

@@ -2,19 +2,15 @@ package server
import ( import (
"context" "context"
"errors"
"fmt"
"slices" "slices"
"github.com/rs/xid" "github.com/rs/xid"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
@@ -80,9 +76,6 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
@@ -139,27 +132,6 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID) return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
} }
// getPeerPostureChecks returns the posture checks applied for a given peer.
func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) {
peerPostureChecks := make(map[string]*posture.Checks)
if len(account.PostureChecks) == 0 {
return nil, nil
}
for _, policy := range account.Policies {
if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
continue
}
if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
return nil, err
}
}
return maps.Values(peerPostureChecks), nil
}
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. // arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) { func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
@@ -214,50 +186,6 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account
return nil return nil
} }
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error {
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
if err != nil {
return err
}
if !isInGroup {
return nil
}
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
postureCheck := account.GetPostureChecks(sourcePostureCheckID)
if postureCheck == nil {
return errors.New("failed to add policy posture checks: posture checks not found")
}
peerPostureChecks[sourcePostureCheckID] = postureCheck
}
return nil
}
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
for _, sourceGroup := range rule.Sources {
group := account.GetGroup(sourceGroup)
if group == nil {
return false, fmt.Errorf("failed to check peer in policy source group: group not found")
}
if slices.Contains(group.Peers, peerID) {
return true, nil
}
}
}
return false, nil
}
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error { func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)

View File

@@ -21,7 +21,7 @@ const (
) )
func TestDefaultAccountManager_PostureCheck(t *testing.T) { func TestDefaultAccountManager_PostureCheck(t *testing.T) {
am, err := createManager(t) am, _, err := createManager(t)
if err != nil { if err != nil {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
@@ -123,7 +123,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
} }
func TestPostureCheckAccountPeersUpdate(t *testing.T) { func TestPostureCheckAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
g := []*types.Group{ g := []*types.Group{
{ {
@@ -147,9 +147,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} }
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() { t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) updateManager.CloseChannel(context.Background(), peer1.ID)
}) })
postureCheckA := &posture.Checks{ postureCheckA := &posture.Checks{
@@ -359,9 +359,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// Updating linked posture check to policy where destination has peers but source does not // Updating linked posture check to policy where destination has peers but source does not
// should trigger account peers update and send peer update // should trigger account peers update and send peer update
t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) { t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) {
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID) updMsg1 := updateManager.CreateChannel(context.Background(), peer2.ID)
t.Cleanup(func() { t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) updateManager.CloseChannel(context.Background(), peer2.ID)
}) })
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
@@ -445,7 +445,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
} }
func TestArePostureCheckChangesAffectPeers(t *testing.T) { func TestArePostureCheckChangesAffectPeers(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
require.NoError(t, err, "failed to create account manager") require.NoError(t, err, "failed to create account manager")
account, err := initTestPostureChecksAccount(manager) account, err := initTestPostureChecksAccount(manager)

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
@@ -192,9 +191,6 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
@@ -249,9 +245,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers { if oldRouteAffectsPeers || newRouteAffectsPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
@@ -295,9 +288,6 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
if updateAccountPeers { if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
} }
@@ -381,103 +371,12 @@ func validateRouteGroups(ctx context.Context, transaction store.Store, accountID
return groupsMap, nil return groupsMap, nil
} }
func toProtocolRoute(route *route.Route) *proto.Route {
return &proto.Route{
ID: string(route.ID),
NetID: string(route.NetID),
Network: route.Network.String(),
Domains: route.Domains.ToPunycodeList(),
NetworkType: int64(route.NetworkType),
Peer: route.Peer,
Metric: int64(route.Metric),
Masquerade: route.Masquerade,
KeepRoute: route.KeepRoute,
SkipAutoApply: route.SkipAutoApply,
}
}
func toProtocolRoutes(routes []*route.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0, len(routes))
for _, r := range routes {
protoRoutes = append(protoRoutes, toProtocolRoute(r))
}
return protoRoutes
}
// getPlaceholderIP returns a placeholder IP address for the route if domains are used // getPlaceholderIP returns a placeholder IP address for the route if domains are used
func getPlaceholderIP() netip.Prefix { func getPlaceholderIP() netip.Prefix {
// Using an IP from the documentation range to minimize impact in case older clients try to set a route // Using an IP from the documentation range to minimize impact in case older clients try to set a route
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
} }
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
result := make([]*proto.RouteFirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.RouteFirewallRule{
SourceRanges: rule.SourceRanges,
Action: getProtoAction(rule.Action),
Destination: rule.Destination,
Protocol: getProtoProtocol(rule.Protocol),
PortInfo: getProtoPortInfo(rule),
IsDynamic: rule.IsDynamic,
Domains: rule.Domains.ToPunycodeList(),
PolicyID: []byte(rule.PolicyID),
RouteID: string(rule.RouteID),
}
}
return result
}
// getProtoDirection converts the direction to proto.RuleDirection.
func getProtoDirection(direction int) proto.RuleDirection {
if direction == types.FirewallRuleDirectionOUT {
return proto.RuleDirection_OUT
}
return proto.RuleDirection_IN
}
// getProtoAction converts the action to proto.RuleAction.
func getProtoAction(action string) proto.RuleAction {
if action == string(types.PolicyTrafficActionDrop) {
return proto.RuleAction_DROP
}
return proto.RuleAction_ACCEPT
}
// getProtoProtocol converts the protocol to proto.RuleProtocol.
func getProtoProtocol(protocol string) proto.RuleProtocol {
switch types.PolicyRuleProtocolType(protocol) {
case types.PolicyRuleProtocolALL:
return proto.RuleProtocol_ALL
case types.PolicyRuleProtocolTCP:
return proto.RuleProtocol_TCP
case types.PolicyRuleProtocolUDP:
return proto.RuleProtocol_UDP
case types.PolicyRuleProtocolICMP:
return proto.RuleProtocol_ICMP
default:
return proto.RuleProtocol_UNKNOWN
}
}
// getProtoPortInfo converts the port info to proto.PortInfo.
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
var portInfo proto.PortInfo
if rule.Port != 0 {
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
portInfo.PortSelection = &proto.PortInfo_Range_{
Range: &proto.PortInfo_Range{
Start: uint32(portRange.Start),
End: uint32(portRange.End),
},
}
}
return &portInfo
}
// areRouteChangesAffectPeers checks if a given route affects peers by determining // areRouteChangesAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers. // if it has a routing peer, distribution, or peer groups that include peers.
func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) { func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {

View File

@@ -14,6 +14,8 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -432,7 +434,7 @@ func TestCreateRoute(t *testing.T) {
} }
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
am, err := createRouterManager(t) am, _, err := createRouterManager(t)
if err != nil { if err != nil {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
@@ -922,7 +924,7 @@ func TestSaveRoute(t *testing.T) {
} }
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
am, err := createRouterManager(t) am, _, err := createRouterManager(t)
if err != nil { if err != nil {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
@@ -1024,7 +1026,7 @@ func TestDeleteRoute(t *testing.T) {
Enabled: true, Enabled: true,
} }
am, err := createRouterManager(t) am, _, err := createRouterManager(t)
if err != nil { if err != nil {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
@@ -1071,7 +1073,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
AccessControlGroups: []string{routeGroup1}, AccessControlGroups: []string{routeGroup1},
} }
am, err := createRouterManager(t) am, _, err := createRouterManager(t)
if err != nil { if err != nil {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
@@ -1163,7 +1165,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
AccessControlGroups: []string{routeGroup1}, AccessControlGroups: []string{routeGroup1},
} }
am, err := createRouterManager(t) am, _, err := createRouterManager(t)
if err != nil { if err != nil {
t.Error("failed to create account manager") t.Error("failed to create account manager")
} }
@@ -1250,11 +1252,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1") require.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1")
} }
func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) {
t.Helper() t.Helper()
store, err := createRouterStore(t) store, err := createRouterStore(t)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
@@ -1285,7 +1287,16 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
am, err := BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, nil, err
}
return am, updateManager, nil
} }
func createRouterStore(t *testing.T) (store.Store, error) { func createRouterStore(t *testing.T) (store.Store, error) {
@@ -1948,7 +1959,7 @@ func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFi
} }
func TestRouteAccountPeersUpdate(t *testing.T) { func TestRouteAccountPeersUpdate(t *testing.T) {
manager, err := createRouterManager(t) manager, updateManager, err := createRouterManager(t)
require.NoError(t, err, "failed to create account manager") require.NoError(t, err, "failed to create account manager")
account, err := initTestRouteAccount(t, manager) account, err := initTestRouteAccount(t, manager)
@@ -1976,9 +1987,9 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
require.NoError(t, err, "failed to create group %s", group.Name) require.NoError(t, err, "failed to create group %s", group.Name)
} }
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID) updMsg := updateManager.CreateChannel(context.Background(), peer1ID)
t.Cleanup(func() { t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID) updateManager.CloseChannel(context.Background(), peer1ID)
}) })
// Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update // Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update

View File

@@ -18,7 +18,7 @@ import (
) )
func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -93,7 +93,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
} }
func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -198,7 +198,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
} }
func TestGetSetupKeys(t *testing.T) { func TestGetSetupKeys(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -396,7 +396,7 @@ func TestSetupKey_Copy(t *testing.T) {
} }
func TestSetupKeyAccountPeersUpdate(t *testing.T) { func TestSetupKeyAccountPeersUpdate(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA", ID: "groupA",
@@ -420,9 +420,9 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
require.NoError(t, err) require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() { t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) updateManager.CloseChannel(context.Background(), peer1.ID)
}) })
var setupKey *types.SetupKey var setupKey *types.SetupKey
@@ -465,7 +465,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
} }
func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) { func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -965,7 +965,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
if err != nil { if err != nil {
return err return err
} }
dnsDomain := am.GetDNSDomain(settings) dnsDomain := am.networkMapController.GetDNSDomain(settings)
var peerIDs []string var peerIDs []string
for _, peer := range peers { for _, peer := range peers {
@@ -992,16 +992,13 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
activity.PeerLoginExpired, peer.EventMeta(dnsDomain), activity.PeerLoginExpired, peer.EventMeta(dnsDomain),
) )
if am.experimentalNetworkMap(accountID) { am.networkMapController.OnPeerUpdated(accountID, peer)
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
} }
if len(peerIDs) != 0 { if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service // this will trigger peer disconnect from the management service
log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID) log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID)
am.peersUpdateManager.CloseChannels(ctx, peerIDs) am.networkMapController.DisconnectPeers(ctx, peerIDs)
am.BufferUpdateAccountPeers(ctx, accountID)
} }
return nil return nil
} }
@@ -1115,6 +1112,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
var addPeerRemovedEvents []func() var addPeerRemovedEvents []func()
var updateAccountPeers bool var updateAccountPeers bool
var userPeers []*nbpeer.Peer
var targetUser *types.User var targetUser *types.User
var err error var err error
@@ -1124,7 +1122,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
return fmt.Errorf("failed to get user to delete: %w", err) return fmt.Errorf("failed to get user to delete: %w", err)
} }
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID) userPeers, err = transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user peers: %w", err) return fmt.Errorf("failed to get user peers: %w", err)
} }
@@ -1147,6 +1145,17 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
return false, err return false, err
} }
for _, peer := range userPeers {
err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err)
}
if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peer.ID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peer.ID, err)
}
}
for _, addPeerRemovedEvent := range addPeerRemovedEvents { for _, addPeerRemovedEvent := range addPeerRemovedEvents {
addPeerRemovedEvent() addPeerRemovedEvent()
} }

View File

@@ -1161,7 +1161,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
} }
func TestDefaultAccountManager_SaveUser(t *testing.T) { func TestDefaultAccountManager_SaveUser(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@@ -1333,7 +1333,7 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
func TestUserAccountPeersUpdate(t *testing.T) { func TestUserAccountPeersUpdate(t *testing.T) {
// account groups propagation is enabled // account groups propagation is enabled
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{
ID: "groupA", ID: "groupA",
@@ -1357,9 +1357,9 @@ func TestUserAccountPeersUpdate(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
require.NoError(t, err) require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) updMsg := updateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() { t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) updateManager.CloseChannel(context.Background(), peer1.ID)
}) })
// Creating a new regular user should not update account peers and not send peer update // Creating a new regular user should not update account peers and not send peer update
@@ -1468,9 +1468,9 @@ func TestUserAccountPeersUpdate(t *testing.T) {
} }
}) })
peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID) peer4UpdMsg := updateManager.CreateChannel(context.Background(), peer4.ID)
t.Cleanup(func() { t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID) updateManager.CloseChannel(context.Background(), peer4.ID)
}) })
// deleting user with linked peers should update account peers and send peer update // deleting user with linked peers should update account peers and send peer update
@@ -1748,7 +1748,7 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions {
} }
func TestApproveUser(t *testing.T) { func TestApproveUser(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -1807,7 +1807,7 @@ func TestApproveUser(t *testing.T) {
} }
func TestRejectUser(t *testing.T) { func TestRejectUser(t *testing.T) {
manager, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -18,6 +18,9 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
@@ -68,7 +71,6 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
} }
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
@@ -111,15 +113,19 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
Return(&types.ExtraSettings{}, nil). Return(&types.ExtraSettings{}, nil).
AnyTimes() AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}) mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}, networkMapController)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }