Compare commits

...

6 Commits

Author SHA1 Message Date
jnfrati
59569c5147 remove logic from handler towards setup service 2026-04-27 17:12:39 +02:00
jnfrati
0c71aca86d Merge branch 'main' of github.com:netbirdio/netbird into feat/pat-creation-on-setup 2026-04-27 16:59:12 +02:00
jnfrati
be9f1b46e6 enable pat creation on setup 2026-04-27 16:54:47 +02:00
Vlad
154b81645a [management] removed legacy network map code (#5565) 2026-04-27 16:02:54 +02:00
Maycon Santos
34167c8a16 [misc] Update release pipeline version (#5995) 2026-04-27 10:55:38 +02:00
Maycon Santos
d6f08e4840 [misc] Update sign pipeline version (#5981) 2026-04-24 13:13:27 +02:00
26 changed files with 1659 additions and 5422 deletions

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.2"
SIGN_PIPE_VER: "v0.1.4"
GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

View File

@@ -7,7 +7,6 @@ import (
"os"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@@ -16,11 +15,9 @@ import (
"golang.org/x/exp/maps"
"golang.org/x/mod/semver"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
@@ -58,13 +55,6 @@ type Controller struct {
proxyController port_forwarding.Controller
integratedPeerValidator integrated_validator.IntegratedValidator
holder *types.Holder
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
compactedNetworkMap bool
}
type bufferUpdate struct {
@@ -81,29 +71,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
}
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
newNetworkMapBuilder = false
}
compactedNetworkMap := true
compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted)
parsedCompactedNmap, err := strconv.ParseBool(compactedEnv)
if err != nil && len(compactedEnv) > 0 {
log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err)
}
if err == nil && !parsedCompactedNmap {
log.WithContext(ctx).Info("disabling compacted mode")
compactedNetworkMap = false
}
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
expIDs[id] = struct{}{}
}
return &Controller{
repo: newRepository(store),
metrics: nMetrics,
@@ -117,12 +84,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
proxyController: proxyController,
EphemeralPeersManager: ephemeralPeersManager,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
compactedNetworkMap: compactedNetworkMap,
}
}
@@ -153,17 +114,9 @@ func (c *Controller) CountStreams() int {
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
var (
account *types.Account
err error
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(ctx, accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
globalStart := time.Now()
@@ -197,10 +150,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.experimentalNetworkMap(accountID) {
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
}
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
@@ -243,16 +192,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
var remotePeerNetworkMap *types.NetworkMap
switch {
case c.experimentalNetworkMap(accountID):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -318,10 +258,6 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
// UpdatePeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return fmt.Errorf("recalculate network map cache: %v", err)
}
return c.sendUpdateAccountPeers(ctx, accountID)
}
@@ -371,16 +307,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return err
}
var remotePeerNetworkMap *types.NetworkMap
switch {
case c.experimentalNetworkMap(accountId):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
@@ -451,17 +378,9 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return peer, emptyMap, nil, 0, nil
}
var (
account *types.Account
err error
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(ctx, accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
account.InjectProxyPolicies(ctx)
@@ -493,20 +412,10 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return nil, nil, nil, 0, err
}
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
@@ -518,108 +427,6 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return peer, networkMap, postureChecks, dnsFwdPort, nil
}
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
c.enrichAccountFromHolder(account)
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
}
func (c *Controller) getPeerNetworkMapExp(
ctx context.Context,
accountId string,
peerId string,
validatedPeers map[string]struct{},
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
account := c.getAccountFromHolderOrInit(ctx, accountId)
if account == nil {
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
return &types.NetworkMap{
Network: &types.Network{},
}
}
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
}
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
c.enrichAccountFromHolder(account)
account.OnPeersAddedUpdNetworkMapCache(peerIds...)
}
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
c.enrichAccountFromHolder(account)
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
}
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
account := c.getAccountFromHolder(accountId)
if account == nil {
return
}
account.UpdatePeerInNetworkMapCache(peer)
}
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
account.RecalculateNetworkMapCache(validatedPeers)
c.updateAccountInHolder(account)
}
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
if c.experimentalNetworkMap(accountId) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
return err
}
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
return err
}
c.recalculateNetworkMapCache(account, validatedPeers)
}
return nil
}
func (c *Controller) experimentalNetworkMap(accountId string) bool {
_, ok := c.expNewNetworkMapAIDs[accountId]
return c.expNewNetworkMap || ok
}
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
a := c.holder.GetAccount(account.Id)
if a == nil {
c.holder.AddAccount(account)
return
}
account.NetworkMapCache = a.NetworkMapCache
if account.NetworkMapCache == nil {
return
}
c.holder.AddAccount(account)
}
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
return c.holder.GetAccount(accountID)
}
func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account {
a := c.holder.GetAccount(accountID)
if a != nil {
return a
}
account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure)
if err != nil {
return nil
}
return account
}
func (c *Controller) updateAccountInHolder(account *types.Account) {
c.holder.AddAccount(account)
}
// GetDNSDomain returns the configured dnsDomain
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
if settings == nil {
@@ -756,16 +563,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
}
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
if err != nil {
return fmt.Errorf("failed to get peers by ids: %w", err)
}
for _, peer := range peers {
c.UpdatePeerInNetworkMapCache(accountID, peer)
}
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
err := c.bufferSendUpdateAccountPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
}
@@ -775,14 +573,6 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs)
c.onPeersAddedUpdNetworkMapCache(account, peerIDs...)
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
}
@@ -817,19 +607,6 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
MessageType: network_map.MessageTypeNetworkMap,
})
c.peersUpdateManager.CloseChannel(ctx, peerID)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
continue
}
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
continue
}
}
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
@@ -872,21 +649,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
return nil, err
}
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
} else {
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
}
}
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {

View File

@@ -12,9 +12,6 @@ import (
)
const (
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
DnsForwarderPort = nbdns.ForwarderServerPort
OldForwarderPort = nbdns.ForwarderClientPort
DnsForwarderPortMinVersion = "v0.59.0"

View File

@@ -408,7 +408,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
networkMap := account.GetPeerNetworkMapFromComponents(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}
@@ -1171,11 +1171,6 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
}
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SaveGroup(t)
}
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
testAccountManager_NetworkUpdates_SaveGroup(t)
}
@@ -1231,11 +1226,6 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePolicy(t)
}
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
testAccountManager_NetworkUpdates_DeletePolicy(t)
}
@@ -1274,11 +1264,6 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SavePolicy(t)
}
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
testAccountManager_NetworkUpdates_SavePolicy(t)
}
@@ -1332,11 +1317,6 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePeer(t)
}
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
testAccountManager_NetworkUpdates_DeletePeer(t)
}
@@ -1397,11 +1377,6 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeleteGroup(t)
}
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
testAccountManager_NetworkUpdates_DeleteGroup(t)
}
@@ -1633,75 +1608,6 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
assert.Contains(t, routeIDs, route.ID("route-2"))
}
func TestAccount_GetRoutesToSync(t *testing.T) {
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
if err != nil {
t.Fatal(err)
}
_, prefix2, err := route.ParseNetwork("192.168.0.0/24")
if err != nil {
t.Fatal(err)
}
account := &types.Account{
Peers: map[string]*nbpeer.Peer{
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
},
Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
Routes: map[route.ID]*route.Route{
"route-1": {
ID: "route-1",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-1",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
Groups: []string{"group1"},
},
"route-2": {
ID: "route-2",
Network: prefix2,
NetID: "network-2",
Description: "network-2",
Peer: "peer-2",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
Groups: []string{"group1"},
},
"route-3": {
ID: "route-3",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-2",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
Groups: []string{"group1"},
},
},
}
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
assert.Len(t, routes, 2)
routeIDs := make(map[route.ID]struct{}, 2)
for _, r := range routes {
routeIDs[r.ID] = struct{}{}
}
assert.Contains(t, routeIDs, route.ID("route-2"))
assert.Contains(t, routeIDs, route.ID("route-3"))
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
assert.Len(t, emptyRoutes, 0)
}
func TestAccount_Copy(t *testing.T) {
account := &types.Account{
Id: "account1",
@@ -1824,9 +1730,7 @@ func TestAccount_Copy(t *testing.T) {
AccountID: "account1",
},
},
NetworkMapCache: &types.NetworkMapBuilder{},
}
account.InitOnce()
err := hasNilField(account)
if err != nil {
t.Fatal(err)

View File

@@ -62,9 +62,7 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
apiPrefix = "/api"
)
const apiPrefix = "/api"
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
@@ -141,7 +139,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
zonesManager.RegisterEndpoints(router, zManager)
recordsManager.RegisterEndpoints(router, rManager)
idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router)
instance.AddEndpoints(instanceManager, accountManager, router)
instance.AddVersionEndpoint(instanceManager, router)
if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)

View File

@@ -7,6 +7,7 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -15,13 +16,15 @@ import (
// handler handles the instance setup HTTP endpoints
type handler struct {
instanceManager nbinstance.Manager
setupManager *nbinstance.SetupService
}
// AddEndpoints registers the instance setup endpoints.
// These endpoints bypass authentication for initial setup.
func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
func AddEndpoints(instanceManager nbinstance.Manager, accountManager account.Manager, router *mux.Router) {
h := &handler{
instanceManager: instanceManager,
setupManager: nbinstance.NewSetupService(instanceManager, accountManager),
}
router.HandleFunc("/instance", h.getInstanceStatus).Methods("GET", "OPTIONS")
@@ -55,24 +58,35 @@ func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
// setup creates the initial admin user for the instance.
// This endpoint is unauthenticated but only works when setup is required.
func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var req api.SetupRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
return
}
userData, err := h.instanceManager.CreateOwnerUser(r.Context(), req.Email, req.Password, req.Name)
result, err := h.setupManager.SetupOwner(ctx, req.Email, req.Password, req.Name, nbinstance.SetupOptions{
CreatePAT: req.CreatePat != nil && *req.CreatePat,
PATExpireInDays: req.PatExpireIn,
})
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(ctx, err, w)
return
}
log.WithContext(r.Context()).Infof("instance setup completed: created user %s", req.Email)
log.WithContext(ctx).Infof("instance setup completed: created user %s", req.Email)
util.WriteJSONObject(r.Context(), w, api.SetupResponse{
UserId: userData.ID,
Email: userData.Email,
})
resp := api.SetupResponse{
UserId: result.User.ID,
Email: result.User.Email,
}
if result.PATPlainToken != "" {
resp.PersonalAccessToken = &result.PATPlainToken
}
util.WriteJSONObject(ctx, w, resp)
}
// getVersionInfo returns version information for NetBird components.

View File

@@ -10,12 +10,18 @@ import (
"net/mail"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/idp"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/management/server/mock_server"
nbstore "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -25,6 +31,7 @@ type mockInstanceManager struct {
isSetupRequired bool
isSetupRequiredFn func(ctx context.Context) (bool, error)
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
rollbackSetupFn func(ctx context.Context, userID string) error
getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error)
}
@@ -67,6 +74,13 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo
}, nil
}
func (m *mockInstanceManager) RollbackSetup(ctx context.Context, userID string) error {
if m.rollbackSetupFn != nil {
return m.rollbackSetupFn(ctx, userID)
}
return nil
}
func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) {
if m.getVersionInfoFn != nil {
return m.getVersionInfoFn(ctx)
@@ -82,8 +96,12 @@ func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.V
var _ nbinstance.Manager = (*mockInstanceManager)(nil)
func setupTestRouter(manager nbinstance.Manager) *mux.Router {
return setupTestRouterWithPAT(manager, nil)
}
func setupTestRouterWithPAT(manager nbinstance.Manager, accountManager account.Manager) *mux.Router {
router := mux.NewRouter()
AddEndpoints(manager, router)
AddEndpoints(manager, accountManager, router)
return router
}
@@ -293,6 +311,222 @@ func TestSetup_ManagerError(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestSetup_PAT_FeatureDisabled_IgnoresCreatePAT(t *testing.T) {
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "false")
manager := &mockInstanceManager{isSetupRequired: true}
// NB_SETUP_PAT_ENABLED=false: request fields must be silently ignored
router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{})
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.SetupResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
assert.Nil(t, response.PersonalAccessToken)
}
func TestSetup_PAT_FlagOmitted_NoPAT(t *testing.T) {
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{})
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.SetupResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
assert.Nil(t, response.PersonalAccessToken)
}
func TestSetup_PAT_MissingExpireIn_DefaultsToOneDay(t *testing.T) {
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
createCalled := false
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
createCalled = true
return &idp.UserData{ID: "u1", Email: email, Name: name}, nil
},
}
accountMgr := &mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
assert.Equal(t, "u1", userAuth.UserId)
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, accountID, initiator, target, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
assert.Equal(t, "acc-1", accountID)
assert.Equal(t, "u1", initiator)
assert.Equal(t, "u1", target)
assert.Equal(t, "setup-token", name)
assert.Equal(t, 1, expiresIn)
return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil
},
}
router := setupTestRouterWithPAT(manager, accountMgr)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.True(t, createCalled)
var response api.SetupResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
require.NotNil(t, response.PersonalAccessToken)
assert.Equal(t, "nbp_plain", *response.PersonalAccessToken)
}
func TestSetup_PAT_ExpireOutOfRange(t *testing.T) {
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{})
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 0}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
}
func TestSetup_PAT_Success(t *testing.T) {
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
},
}
gotAccountArgs := struct {
userID string
email string
}{}
accountMgr := &mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
gotAccountArgs.userID = userAuth.UserId
gotAccountArgs.email = userAuth.Email
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, accountID, initiator, target, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
assert.Equal(t, "acc-1", accountID)
assert.Equal(t, "owner-id", initiator)
assert.Equal(t, "owner-id", target)
assert.Equal(t, "setup-token", name)
assert.Equal(t, 30, expiresIn)
return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil
},
}
router := setupTestRouterWithPAT(manager, accountMgr)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.SetupResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
assert.Equal(t, "owner-id", response.UserId)
require.NotNil(t, response.PersonalAccessToken)
assert.Equal(t, "nbp_plain", *response.PersonalAccessToken)
assert.Equal(t, "owner-id", gotAccountArgs.userID)
}
func TestSetup_PAT_AccountCreationFails_Rollback(t *testing.T) {
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
rolledBackFor := ""
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
},
rollbackSetupFn: func(_ context.Context, userID string) error {
rolledBackFor = userID
return nil
},
}
accountMgr := &mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
return "", errors.New("db down")
},
}
router := setupTestRouterWithPAT(manager, accountMgr)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Equal(t, "owner-id", rolledBackFor, "RollbackSetup must be called with the created user id")
}
func TestSetup_PAT_CreatePATFails_Rollback(t *testing.T) {
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
ctrl := gomock.NewController(t)
accountStore := nbstore.NewMockStore(ctrl)
account := &types.Account{Id: "acc-1"}
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil)
rolledBackFor := ""
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
},
rollbackSetupFn: func(_ context.Context, userID string) error {
rolledBackFor = userID
return nil
},
}
accountMgr := &mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) {
return nil, status.Errorf(status.Internal, "token store unavailable")
},
GetStoreFunc: func() nbstore.Store {
return accountStore
},
}
router := setupTestRouterWithPAT(manager, accountMgr)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Equal(t, "owner-id", rolledBackFor, "RollbackSetup must be called when CreatePAT fails")
}
func TestGetVersionInfo_Success(t *testing.T) {
manager := &mockInstanceManager{}
router := mux.NewRouter()

View File

@@ -417,7 +417,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
netMap := account.GetPeerNetworkMapFromComponents(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}

View File

@@ -12,6 +12,7 @@ import (
"sync"
"time"
"github.com/dexidp/dex/storage"
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
@@ -60,6 +61,13 @@ type Manager interface {
// This should only be called when IsSetupRequired returns true.
CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error)
// RollbackSetup reverses a successful CreateOwnerUser by deleting the user
// from the embedded IDP and reloading setupRequired from persistent state, so
// /api/setup can be retried only when no accounts or local users remain. Used
// when post-user steps (account or PAT creation) fail and the caller wants a
// clean slate.
RollbackSetup(ctx context.Context, userID string) error
// GetVersionInfo returns version information for NetBird components.
GetVersionInfo(ctx context.Context) (*VersionInfo, error)
}
@@ -70,6 +78,7 @@ type instanceStore interface {
type embeddedIdP interface {
CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error)
DeleteUser(ctx context.Context, userID string) error
GetAllAccounts(ctx context.Context) (map[string][]*idp.UserData, error)
}
@@ -187,6 +196,51 @@ func (m *DefaultManager) CreateOwnerUser(ctx context.Context, email, password, n
return userData, nil
}
// RollbackSetup undoes a successful CreateOwnerUser: deletes the user from the
// embedded IDP and reloads setupRequired from persistent state.
func (m *DefaultManager) RollbackSetup(ctx context.Context, userID string) error {
if m.embeddedIdpManager == nil {
return errors.New("embedded IDP is not enabled")
}
var deleteErr error
if err := m.embeddedIdpManager.DeleteUser(ctx, userID); err != nil {
if isNotFoundError(err) {
log.WithContext(ctx).Debugf("setup rollback user %s already deleted", userID)
} else {
deleteErr = fmt.Errorf("failed to delete user from embedded IdP: %w", err)
}
}
if err := m.loadSetupRequired(ctx); err != nil {
reloadErr := fmt.Errorf("failed to reload setup state after rollback: %w", err)
if deleteErr != nil {
return errors.Join(deleteErr, reloadErr)
}
return reloadErr
}
if deleteErr != nil {
return deleteErr
}
log.WithContext(ctx).Infof("rolled back setup for user %s", userID)
return nil
}
func isNotFoundError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, storage.ErrNotFound) {
return true
}
if s, ok := status.FromError(err); ok {
return s.Type() == status.NotFound
}
return false
}
func (m *DefaultManager) checkSetupRequiredFromDB(ctx context.Context) error {
numAccounts, err := m.store.GetAccountsCounter(ctx)
if err != nil {

View File

@@ -10,16 +10,19 @@ import (
"testing"
"time"
"github.com/dexidp/dex/storage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/shared/management/status"
)
type mockIdP struct {
mu sync.Mutex
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
users map[string][]*idp.UserData
mu sync.Mutex
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
deleteUserFunc func(ctx context.Context, userID string) error
users map[string][]*idp.UserData
getAllAccountsErr error
}
@@ -30,6 +33,13 @@ func (m *mockIdP) CreateUserWithPassword(ctx context.Context, email, password, n
return &idp.UserData{ID: "test-user-id", Email: email, Name: name}, nil
}
func (m *mockIdP) DeleteUser(ctx context.Context, userID string) error {
if m.deleteUserFunc != nil {
return m.deleteUserFunc(ctx, userID)
}
return nil
}
func (m *mockIdP) GetAllAccounts(_ context.Context) (map[string][]*idp.UserData, error) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -223,6 +233,77 @@ func TestIsSetupRequired_ReturnsFlag(t *testing.T) {
assert.False(t, required)
}
func TestRollbackSetup_UserAlreadyDeletedIsSuccess(t *testing.T) {
tests := []struct {
name string
err error
}{
{
name: "management status not found",
err: status.NewUserNotFoundError("owner-id"),
},
{
name: "dex storage not found",
err: fmt.Errorf("failed to get user for deletion: %w", storage.ErrNotFound),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
idpMock := &mockIdP{
deleteUserFunc: func(_ context.Context, userID string) error {
assert.Equal(t, "owner-id", userID)
return tt.err
},
}
mgr := newTestManager(idpMock, &mockStore{})
mgr.setupRequired = false
err := mgr.RollbackSetup(context.Background(), "owner-id")
require.NoError(t, err)
required, err := mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.True(t, required, "setup should be required when no accounts or local users remain")
})
}
}
func TestRollbackSetup_RecomputesSetupStateWhenAccountStillExists(t *testing.T) {
idpMock := &mockIdP{
deleteUserFunc: func(_ context.Context, _ string) error {
return status.NewUserNotFoundError("owner-id")
},
}
mgr := newTestManager(idpMock, &mockStore{accountsCount: 1})
mgr.setupRequired = true
err := mgr.RollbackSetup(context.Background(), "owner-id")
require.NoError(t, err)
required, err := mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required, "setup should not be required while an account still exists")
}
func TestRollbackSetup_ReturnsDeleteErrorButReloadsSetupState(t *testing.T) {
idpMock := &mockIdP{
deleteUserFunc: func(_ context.Context, _ string) error {
return errors.New("idp unavailable")
},
}
mgr := newTestManager(idpMock, &mockStore{})
mgr.setupRequired = false
err := mgr.RollbackSetup(context.Background(), "owner-id")
require.Error(t, err)
assert.Contains(t, err.Error(), "idp unavailable")
required, err := mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.True(t, required, "setup state should be reloaded even when user deletion fails")
}
func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
manager := &DefaultManager{setupRequired: true}

View File

@@ -0,0 +1,185 @@
package instance
import (
"context"
"fmt"
"os"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
setupPATTokenName = "setup-token"
// SetupPATEnabledEnvKey enables setup-time Personal Access Token creation.
SetupPATEnabledEnvKey = "NB_SETUP_PAT_ENABLED"
setupPATDefaultExpireDays = 1
// patMinExpireDays and patMaxExpireDays mirror the bounds enforced by
// DefaultAccountManager.CreatePAT in management/server/user.go. They are
// duplicated here so /api/setup can reject invalid input before it creates
// the embedded-IdP user.
patMinExpireDays = 1
patMaxExpireDays = 365
)
// SetupOptions controls optional work performed during initial instance setup.
type SetupOptions struct {
// CreatePAT requests creation of a setup Personal Access Token. It is honored
// only when SetupPATEnabledEnvKey is set to "true".
CreatePAT bool
// PATExpireInDays defaults to 1 day when CreatePAT is requested and setup PAT
// creation is enabled.
PATExpireInDays *int
}
// SetupResult contains resources created during initial instance setup.
type SetupResult struct {
User *idp.UserData
PATPlainToken string
}
// SetupService orchestrates the initial setup use case across the instance and
// account bounded contexts and owns the compensation logic when a later step
// fails.
type SetupService struct {
instanceManager Manager
accountManager account.Manager
setupPATEnabled bool
}
// NewSetupService creates a setup use-case service.
func NewSetupService(instanceManager Manager, accountManager account.Manager) *SetupService {
return &SetupService{
instanceManager: instanceManager,
accountManager: accountManager,
setupPATEnabled: os.Getenv(SetupPATEnabledEnvKey) == "true",
}
}
func normalizeSetupOptions(opts SetupOptions, setupPATEnabled bool) (SetupOptions, error) {
if !opts.CreatePAT {
return opts, nil
}
if !setupPATEnabled {
opts.CreatePAT = false
opts.PATExpireInDays = nil
return opts, nil
}
if opts.PATExpireInDays == nil {
defaultExpireInDays := setupPATDefaultExpireDays
opts.PATExpireInDays = &defaultExpireInDays
}
if *opts.PATExpireInDays < patMinExpireDays || *opts.PATExpireInDays > patMaxExpireDays {
return opts, status.Errorf(status.InvalidArgument, "pat_expire_in must be between %d and %d", patMinExpireDays, patMaxExpireDays)
}
return opts, nil
}
// SetupOwner creates the initial owner user and, when requested and enabled by
// SetupPATEnabledEnvKey, provisions the account and a setup Personal Access
// Token. If account or PAT provisioning fails, created resources are rolled
// back so setup can be retried.
func (m *SetupService) SetupOwner(ctx context.Context, email, password, name string, opts SetupOptions) (*SetupResult, error) {
opts, err := normalizeSetupOptions(opts, m.setupPATEnabled)
if err != nil {
return nil, err
}
userData, err := m.instanceManager.CreateOwnerUser(ctx, email, password, name)
if err != nil {
return nil, err
}
result := &SetupResult{User: userData}
if !opts.CreatePAT {
return result, nil
}
if m.accountManager == nil {
err := fmt.Errorf("account manager is required to create setup PAT")
m.rollbackSetup(ctx, userData.ID, "setup PAT requested without account manager", err, "")
return nil, err
}
userAuth := auth.UserAuth{
UserId: userData.ID,
Email: userData.Email,
Name: userData.Name,
}
accountID, err := m.accountManager.GetAccountIDByUserID(ctx, userAuth)
if err != nil {
err = fmt.Errorf("create account for setup user: %w", err)
m.rollbackSetup(ctx, userData.ID, "account provisioning failed", err, "")
return nil, err
}
pat, err := m.accountManager.CreatePAT(ctx, accountID, userData.ID, userData.ID, setupPATTokenName, *opts.PATExpireInDays)
if err != nil {
err = fmt.Errorf("create setup PAT: %w", err)
m.rollbackSetup(ctx, userData.ID, "setup PAT provisioning failed", err, accountID)
return nil, err
}
result.PATPlainToken = pat.PlainToken
return result, nil
}
func (m *SetupService) rollbackSetup(ctx context.Context, userID, reason string, origErr error, accountID string) {
if accountID != "" {
if err := m.rollbackSetupAccount(ctx, accountID); err != nil {
log.WithContext(ctx).Errorf("failed to roll back setup account %s for user %s after %s: original error: %v, rollback error: %v", accountID, userID, reason, origErr, err)
} else {
log.WithContext(ctx).Warnf("rolled back setup account %s for user %s after %s: %v", accountID, userID, reason, origErr)
}
}
if err := m.instanceManager.RollbackSetup(ctx, userID); err != nil {
log.WithContext(ctx).Errorf("failed to roll back setup user %s after %s: original error: %v, rollback error: %v", userID, reason, origErr, err)
return
}
log.WithContext(ctx).Warnf("rolled back setup user %s after %s: %v", userID, reason, origErr)
}
// rollbackSetupAccount removes only the setup-created account data from the
// store. It intentionally avoids accountManager.DeleteAccount because the normal
// account deletion path also deletes users from the IdP; embedded IdP cleanup is
// owned by instanceManager.RollbackSetup.
func (m *SetupService) rollbackSetupAccount(ctx context.Context, accountID string) error {
if m.accountManager == nil {
return fmt.Errorf("account manager is required to roll back setup account")
}
accountStore := m.accountManager.GetStore()
if accountStore == nil {
return fmt.Errorf("account store is unavailable")
}
account, err := accountStore.GetAccount(ctx, accountID)
if err != nil {
if isNotFoundError(err) {
return nil
}
return fmt.Errorf("get setup account for rollback: %w", err)
}
if err := accountStore.DeleteAccount(ctx, account); err != nil {
if isNotFoundError(err) {
return nil
}
return fmt.Errorf("delete setup account for rollback: %w", err)
}
return nil
}

View File

@@ -0,0 +1,241 @@
package instance
import (
"context"
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/mock_server"
nbstore "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/status"
)
type setupInstanceManagerMock struct {
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
rollbackSetupFn func(ctx context.Context, userID string) error
}
func (m *setupInstanceManagerMock) IsSetupRequired(context.Context) (bool, error) {
return true, nil
}
func (m *setupInstanceManagerMock) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if m.createOwnerUserFn != nil {
return m.createOwnerUserFn(ctx, email, password, name)
}
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
}
func (m *setupInstanceManagerMock) RollbackSetup(ctx context.Context, userID string) error {
if m.rollbackSetupFn != nil {
return m.rollbackSetupFn(ctx, userID)
}
return nil
}
func (m *setupInstanceManagerMock) GetVersionInfo(context.Context) (*VersionInfo, error) {
return nil, nil
}
var _ Manager = (*setupInstanceManagerMock)(nil)
func intPtr(v int) *int {
return &v
}
func TestSetupOwner_PATFeatureDisabled_IgnoresCreatePAT(t *testing.T) {
t.Setenv(SetupPATEnabledEnvKey, "false")
createCalls := 0
setupManager := NewSetupService(
&setupInstanceManagerMock{
createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) {
createCalls++
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
},
},
&mock_server.MockAccountManager{},
)
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
CreatePAT: true,
})
require.NoError(t, err)
require.NotNil(t, result)
assert.Equal(t, "owner-id", result.User.ID)
assert.Empty(t, result.PATPlainToken)
assert.Equal(t, 1, createCalls)
}
func TestSetupOwner_PATFeatureEnabled_MissingExpireDefaultsToOneDay(t *testing.T) {
t.Setenv(SetupPATEnabledEnvKey, "true")
createCalled := false
setupManager := NewSetupService(
&setupInstanceManagerMock{
createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) {
createCalled = true
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
},
},
&mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
assert.Equal(t, "owner-id", userAuth.UserId)
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, accountID, initiatorUserID, targetUserID, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
assert.Equal(t, "acc-1", accountID)
assert.Equal(t, "owner-id", initiatorUserID)
assert.Equal(t, "owner-id", targetUserID)
assert.Equal(t, setupPATTokenName, tokenName)
assert.Equal(t, setupPATDefaultExpireDays, expiresIn)
return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil
},
},
)
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
CreatePAT: true,
})
require.NoError(t, err)
require.NotNil(t, result)
assert.True(t, createCalled)
assert.Equal(t, "nbp_plain", result.PATPlainToken)
}
func TestSetupOwner_CreatePATFails_RollsBackSetupAccountAndUser(t *testing.T) {
t.Setenv(SetupPATEnabledEnvKey, "true")
ctrl := gomock.NewController(t)
accountStore := nbstore.NewMockStore(ctrl)
account := &types.Account{Id: "acc-1"}
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil)
rollbackCalls := 0
setupManager := NewSetupService(
&setupInstanceManagerMock{
rollbackSetupFn: func(_ context.Context, userID string) error {
rollbackCalls++
assert.Equal(t, "owner-id", userID)
return nil
},
},
&mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
assert.Equal(t, "owner-id", userAuth.UserId)
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, accountID, initiatorUserID, targetUserID, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
assert.Equal(t, "acc-1", accountID)
assert.Equal(t, "owner-id", initiatorUserID)
assert.Equal(t, "owner-id", targetUserID)
assert.Equal(t, setupPATTokenName, tokenName)
assert.Equal(t, 30, expiresIn)
return nil, status.Errorf(status.Internal, "token store unavailable")
},
GetStoreFunc: func() nbstore.Store {
return accountStore
},
},
)
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
CreatePAT: true,
PATExpireInDays: intPtr(30),
})
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "create setup PAT")
assert.Equal(t, 1, rollbackCalls)
}
func TestSetupOwner_CreatePATFails_AccountAlreadyGoneStillRollsBackUser(t *testing.T) {
t.Setenv(SetupPATEnabledEnvKey, "true")
ctrl := gomock.NewController(t)
accountStore := nbstore.NewMockStore(ctrl)
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(nil, status.NewAccountNotFoundError("acc-1"))
rolledBackFor := ""
setupManager := NewSetupService(
&setupInstanceManagerMock{
rollbackSetupFn: func(_ context.Context, userID string) error {
rolledBackFor = userID
return nil
},
},
&mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) {
return nil, errors.New("token failure")
},
GetStoreFunc: func() nbstore.Store {
return accountStore
},
},
)
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
CreatePAT: true,
PATExpireInDays: intPtr(30),
})
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "create setup PAT")
assert.Equal(t, "owner-id", rolledBackFor)
}
func TestSetupOwner_CreatePATFails_AccountRollbackFailureStillRollsBackUser(t *testing.T) {
t.Setenv(SetupPATEnabledEnvKey, "true")
ctrl := gomock.NewController(t)
accountStore := nbstore.NewMockStore(ctrl)
account := &types.Account{Id: "acc-1"}
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(errors.New("delete failed"))
rolledBackFor := ""
setupManager := NewSetupService(
&setupInstanceManagerMock{
rollbackSetupFn: func(_ context.Context, userID string) error {
rolledBackFor = userID
return nil
},
},
&mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) {
return nil, errors.New("token failure")
},
GetStoreFunc: func() nbstore.Store {
return accountStore
},
},
)
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
CreatePAT: true,
PATExpireInDays: intPtr(30),
})
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "create setup PAT")
assert.Equal(t, "owner-id", rolledBackFor)
}

View File

@@ -179,11 +179,6 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
testGetNetworkMapGeneral(t)
}
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testGetNetworkMapGeneral(t)
}
func testGetNetworkMapGeneral(t *testing.T) {
manager, _, err := createManager(t)
if err != nil {
@@ -1016,11 +1011,6 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
}
}
func TestUpdateAccountPeers_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testUpdateAccountPeers(t)
}
func TestUpdateAccountPeers(t *testing.T) {
testUpdateAccountPeers(t)
}
@@ -1600,7 +1590,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
}
func Test_LoginPeer(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}

View File

@@ -2,10 +2,8 @@ package server
import (
"context"
"fmt"
"net"
"net/netip"
"sort"
"testing"
"time"
@@ -1840,11 +1838,6 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
},
}
validatedPeers := make(map[string]struct{})
for p := range account.Peers {
validatedPeers[p] = struct{}{}
}
t.Run("check applied policies for the route", func(t *testing.T) {
route1 := account.Routes["route1"]
policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups)
@@ -1858,116 +1851,6 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups)
assert.Len(t, policies, 0)
})
t.Run("check peer routes firewall rules", func(t *testing.T) {
routesFirewallRules := account.GetPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
assert.Len(t, routesFirewallRules, 4)
expectedRoutesFirewallRules := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
Protocol: "all",
Port: 80,
RouteID: "route1:peerA",
},
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
Protocol: "all",
Port: 320,
RouteID: "route1:peerA",
},
}
additionalFirewallRule := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerJIp),
},
Action: "accept",
Destination: "192.168.10.0/16",
Protocol: "tcp",
Port: 80,
RouteID: "route4:peerA",
},
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerKIp),
},
Action: "accept",
Destination: "192.168.10.0/16",
Protocol: "all",
RouteID: "route4:peerA",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...)))
// peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
assert.Len(t, routesFirewallRules, 2)
for _, rule := range expectedRoutesFirewallRules {
rule.RouteID = "route1:peerD"
}
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
// peerE is a single routing peer for route 2 and route 3
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
assert.Len(t, routesFirewallRules, 3)
expectedRoutesFirewallRules = []*types.RouteFirewallRule{
{
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
Action: "accept",
Destination: existingNetwork.String(),
Protocol: "tcp",
PortRange: types.RulePortRange{Start: 80, End: 350},
RouteID: "route2",
},
{
SourceRanges: []string{"0.0.0.0/0"},
Action: "accept",
Destination: "192.0.2.0/32",
Protocol: "all",
Domains: domain.List{"example.com"},
IsDynamic: true,
RouteID: "route3",
},
{
SourceRanges: []string{"::/0"},
Action: "accept",
Destination: "192.0.2.0/32",
Protocol: "all",
Domains: domain.List{"example.com"},
IsDynamic: true,
RouteID: "route3",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
// peerC is part of route1 distribution groups but should not receive the routes firewall rules
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
assert.Len(t, routesFirewallRules, 0)
})
}
// orderList is a helper function to sort a list of strings
func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFirewallRule {
for _, rule := range ruleList {
sort.Strings(rule.SourceRanges)
}
return ruleList
}
func TestRouteAccountPeersUpdate(t *testing.T) {
@@ -2665,11 +2548,6 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) {
},
}
validatedPeers := make(map[string]struct{})
for p := range account.Peers {
validatedPeers[p] = struct{}{}
}
t.Run("validate applied policies for different network resources", func(t *testing.T) {
// Test case: Resource1 is directly applied to the policy (policyResource1)
policies := account.GetPoliciesForNetworkResource("resource1")
@@ -2693,127 +2571,4 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) {
policies = account.GetPoliciesForNetworkResource("resource6")
assert.Len(t, policies, 1, "resource6 should have exactly 1 policy applied via access control groups")
})
t.Run("validate routing peer firewall rules for network resources", func(t *testing.T) {
resourcePoliciesMap := account.GetResourcePoliciesMap()
resourceRoutersMap := account.GetResourceRoutersMap()
_, routes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), "peerA", resourcePoliciesMap, resourceRoutersMap)
firewallRules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerA"], validatedPeers, routes, resourcePoliciesMap)
assert.Len(t, firewallRules, 4)
assert.Len(t, sourcePeers, 5)
expectedFirewallRules := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
Protocol: "all",
Port: 80,
RouteID: "resource2:peerA",
},
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
Protocol: "all",
Port: 320,
RouteID: "resource2:peerA",
},
}
additionalFirewallRules := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerJIp),
},
Action: "accept",
Destination: "192.0.2.0/32",
Protocol: "tcp",
Port: 80,
Domains: domain.List{"example.com"},
IsDynamic: true,
RouteID: "resource4:peerA",
},
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerKIp),
},
Action: "accept",
Destination: "192.0.2.0/32",
Protocol: "all",
Domains: domain.List{"example.com"},
IsDynamic: true,
RouteID: "resource4:peerA",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(append(expectedFirewallRules, additionalFirewallRules...)))
// peerD is also the routing peer for resource2
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerD", resourcePoliciesMap, resourceRoutersMap)
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerD"], validatedPeers, routes, resourcePoliciesMap)
assert.Len(t, firewallRules, 2)
for _, rule := range expectedFirewallRules {
rule.RouteID = "resource2:peerD"
}
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
assert.Len(t, sourcePeers, 3)
// peerE is a single routing peer for resource1 and resource3
// PeerE should only receive rules for resource1 since resource3 has no applied policy
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerE", resourcePoliciesMap, resourceRoutersMap)
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerE"], validatedPeers, routes, resourcePoliciesMap)
assert.Len(t, firewallRules, 1)
assert.Len(t, sourcePeers, 2)
expectedFirewallRules = []*types.RouteFirewallRule{
{
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
Action: "accept",
Destination: "10.10.10.0/24",
Protocol: "tcp",
PortRange: types.RulePortRange{Start: 80, End: 350},
RouteID: "resource1:peerE",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
// peerC is part of distribution groups for resource2 but should not receive the firewall rules
firewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
assert.Len(t, firewallRules, 0)
// peerL is the single routing peer for resource5
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerL", resourcePoliciesMap, resourceRoutersMap)
assert.Len(t, routes, 1)
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerL"], validatedPeers, routes, resourcePoliciesMap)
assert.Len(t, firewallRules, 1)
assert.Len(t, sourcePeers, 1)
expectedFirewallRules = []*types.RouteFirewallRule{
{
SourceRanges: []string{"100.65.29.67/32"},
Action: "accept",
Destination: "10.12.12.1/32",
Protocol: "tcp",
Port: 8080,
RouteID: "resource5:peerL",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerM", resourcePoliciesMap, resourceRoutersMap)
assert.Len(t, routes, 1)
assert.Len(t, sourcePeers, 0)
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerN", resourcePoliciesMap, resourceRoutersMap)
assert.Len(t, routes, 1)
assert.Len(t, sourcePeers, 2)
})
}

View File

@@ -1196,7 +1196,6 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
account.NameServerGroups[ns.ID] = &ns
}
account.NameServerGroupsG = nil
account.InitOnce()
return &account, nil
}
@@ -1635,7 +1634,6 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc
if sExtraIntegratedValidatorGroups.Valid {
_ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups)
}
account.InitOnce()
return &account, nil
}

View File

@@ -8,7 +8,6 @@ import (
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
@@ -27,7 +26,6 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
@@ -110,16 +108,9 @@ type Account struct {
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
NetworkMapCache *NetworkMapBuilder `gorm:"-"`
nmapInitOnce *sync.Once `gorm:"-"`
ReverseProxyFreeDomainNonce string
}
func (a *Account) InitOnce() {
a.nmapInitOnce = &sync.Once{}
}
// this class is used by gorm only
type PrimaryAccountInfo struct {
IsDomainPrimaryAccount bool
@@ -155,108 +146,6 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
o.SignupFormPending == onboarding.SignupFormPending
}
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
// from the ACL peers that have distribution groups associated with the peer ID.
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route {
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)
peerRoutesMembership := make(LookupMap)
for _, r := range append(routes, peerDisabledRoutes...) {
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
}
for _, peer := range aclPeers {
activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID)
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, peerGroups)
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
routes = append(routes, filteredRoutes...)
}
return routes
}
// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
_, found := peerMemberships[string(r.GetHAUniqueID())]
if !found {
filteredRoutes = append(filteredRoutes, r)
}
}
return filteredRoutes
}
// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map
func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
for _, groupID := range r.Groups {
_, found := groupListMap[groupID]
if found {
filteredRoutes = append(filteredRoutes, r)
break
}
}
}
return filteredRoutes
}
// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
// If the given is not a routing peer, then the lists are empty.
func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
peer := a.GetPeer(peerID)
if peer == nil {
log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id)
return enabledRoutes, disabledRoutes
}
seenRoute := make(map[route.ID]struct{})
takeRoute := func(r *route.Route, id string) {
if _, ok := seenRoute[r.ID]; ok {
return
}
seenRoute[r.ID] = struct{}{}
if r.Enabled {
r.Peer = peer.Key
enabledRoutes = append(enabledRoutes, r)
return
}
disabledRoutes = append(disabledRoutes, r)
}
for _, r := range a.Routes {
for _, groupID := range r.PeerGroups {
group := a.GetGroup(groupID)
if group == nil {
log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id)
continue
}
for _, id := range group.Peers {
if id != peerID {
continue
}
newPeerRoute := r.Copy()
newPeerRoute.Peer = id
newPeerRoute.PeerGroups = nil
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map
takeRoute(newPeerRoute, id)
break
}
}
if r.Peer == peerID {
takeRoute(r.Copy(), peerID)
}
}
return enabledRoutes, disabledRoutes
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route {
var routes []*route.Route
@@ -276,106 +165,6 @@ func (a *Account) GetGroup(groupID string) *Group {
return a.Groups[groupID]
}
// GetPeerNetworkMap returns the networkmap for the given peer ID.
func (a *Account) GetPeerNetworkMap(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
validatedPeersMap map[string]struct{},
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
metrics *telemetry.AccountManagerMetrics,
groupIDToUserIDs map[string][]string,
) *NetworkMap {
start := time.Now()
peer := a.Peers[peerID]
if peer == nil {
return &NetworkMap{
Network: a.Network.Copy(),
}
}
if _, ok := validatedPeersMap[peerID]; !ok {
return &NetworkMap{
Network: a.Network.Copy(),
}
}
peerGroups := a.GetPeerGroups(peerID)
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
// exclude expired peers
var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer
for _, p := range aclPeers {
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
if a.Settings.PeerLoginExpirationEnabled && expired {
expiredPeers = append(expiredPeers, p)
continue
}
peersToConnect = append(peersToConnect, p)
}
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect, peerGroups)
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
var networkResourcesFirewallRules []*RouteFirewallRule
if isRouter {
networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies)
}
peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers)
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
dnsUpdate := nbdns.Config{
ServiceEnable: dnsManagementStatus,
}
if dnsManagementStatus {
var zones []nbdns.CustomZone
if peersCustomZone.Domain != "" {
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
zones = append(zones, nbdns.CustomZone{
Domain: peersCustomZone.Domain,
Records: records,
})
}
filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
zones = append(zones, filteredAccountZones...)
dnsUpdate.CustomZones = zones
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
}
nm := &NetworkMap{
Peers: peersToConnectIncludingRouters,
Network: a.Network.Copy(),
Routes: slices.Concat(networkResourcesRoutes, routesUpdate),
DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
AuthorizedUsers: authorizedUsers,
EnableSSH: enableSSH,
}
if metrics != nil {
objectCount := int64(len(peersToConnectIncludingRouters) + len(expiredPeers) + len(routesUpdate) + len(networkResourcesRoutes) + len(firewallRules) + +len(networkResourcesFirewallRules) + len(routesFirewallRules))
metrics.CountNetworkMapObjects(objectCount)
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
if objectCount > 5000 {
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+
"peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d, network resources routes: %d, network resources firewall rules: %d, routes firewall rules: %d",
a.Id, objectCount, len(peersToConnectIncludingRouters), len(expiredPeers), len(routesUpdate), len(firewallRules), len(networkResourcesRoutes), len(networkResourcesFirewallRules), len(routesFirewallRules))
}
}
return nm
}
func (a *Account) addNetworksRoutingPeers(
networkResourcesRoutes []*route.Route,
peer *nbpeer.Peer,
@@ -421,39 +210,6 @@ func (a *Account) addNetworksRoutingPeers(
return peersToConnect
}
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
groupList := account.GetPeerGroups(peerID)
var peerNSGroups []*nbdns.NameServerGroup
for _, nsGroup := range account.NameServerGroups {
if !nsGroup.Enabled {
continue
}
for _, gID := range nsGroup.Groups {
_, found := groupList[gID]
if found {
if !peerIsNameserver(account.GetPeer(peerID), nsGroup) {
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
break
}
}
}
}
return peerNSGroups
}
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup
func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
for _, ns := range nsGroup.NameServers {
if peer.IP.Equal(ns.IP.AsSlice()) {
return true
}
}
return false
}
func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) {
for _, peer := range account.Peers {
label, err := GetPeerHostLabel(peer.Name, peerLabels)
@@ -800,19 +556,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
return grps
}
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := a.GetPeerGroups(peerID)
enabled := true
for _, groupID := range a.DNSSettings.DisabledManagementGroups {
_, found := peerGroups[groupID]
if found {
enabled = false
break
}
}
return enabled
}
func (a *Account) GetPeerGroups(peerID string) LookupMap {
groupList := make(LookupMap)
for groupID, group := range a.Groups {
@@ -941,8 +684,6 @@ func (a *Account) Copy() *Account {
NetworkResources: networkResources,
Services: services,
Onboarding: a.Onboarding,
NetworkMapCache: a.NetworkMapCache,
nmapInitOnce: a.nmapInitOnce,
Domains: domains,
}
}
@@ -1304,31 +1045,6 @@ func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks {
return nil
}
// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))
enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
for _, route := range enabledRoutes {
// If no access control groups are specified, accept all traffic.
if len(route.AccessControlGroups) == 0 {
defaultPermit := getDefaultPermit(route)
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
continue
}
distributionPeers := a.getDistributionGroupsPeers(route)
for _, accessGroup := range route.AccessControlGroups {
policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup})
rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers)
routesFirewallRules = append(routesFirewallRules, rules...)
}
}
return routesFirewallRules
}
func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule {
var fwRules []*RouteFirewallRule
for _, policy := range policies {
@@ -1387,50 +1103,6 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID
return distributionGroupPeers
}
func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} {
distPeers := make(map[string]struct{})
for _, id := range route.Groups {
group := a.Groups[id]
if group == nil {
continue
}
for _, pID := range group.Peers {
distPeers[pID] = struct{}{}
}
}
return distPeers
}
func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
var rules []*RouteFirewallRule
sources := []string{"0.0.0.0/0"}
if route.Network.Addr().Is6() {
sources = []string{"::/0"}
}
rule := RouteFirewallRule{
SourceRanges: sources,
Action: string(PolicyTrafficActionAccept),
Destination: route.Network.String(),
Protocol: string(PolicyRuleProtocolALL),
Domains: route.Domains,
IsDynamic: route.IsDynamic(),
RouteID: route.ID,
}
rules = append(rules, &rule)
// dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally
if route.IsDynamic() {
ruleV6 := rule
ruleV6.SourceRanges = []string{"::/0"}
rules = append(rules, &ruleV6)
}
return rules
}
// GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
// and returns a list of policies that have rules with destinations matching the specified groups.
func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
@@ -1508,65 +1180,6 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
return resourcePolicies
}
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
var isRoutingPeer bool
var routes []*route.Route
allSourcePeers := make(map[string]struct{}, len(a.Peers))
for _, resource := range a.NetworkResources {
if !resource.Enabled {
continue
}
var addSourcePeers bool
networkRoutingPeers, exists := routers[resource.NetworkID]
if exists {
if router, ok := networkRoutingPeers[peerID]; ok {
isRoutingPeer, addSourcePeers = true, true
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, router, resourcePolicies)...)
}
}
addedResourceRoute := false
for _, policy := range resourcePolicies[resource.ID] {
var peers []string
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
peers = []string{policy.Rules[0].SourceResource.ID}
} else {
peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
}
if addSourcePeers {
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
allSourcePeers[pID] = struct{}{}
}
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
// add routes for the resource if the peer is in the distribution group
for peerId, router := range networkRoutingPeers {
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
}
addedResourceRoute = true
}
if addedResourceRoute {
break
}
}
}
return isRoutingPeer, routes, allSourcePeers
}
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
var dest []string
for _, peerID := range inputPeers {
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
dest = append(dest, peerID)
}
}
return dest
}
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
for _, groupID := range groups {
@@ -1658,22 +1271,6 @@ func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string {
return result
}
// getNetworkResourcesRoutes convert the network resources list to routes list.
func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route {
resourceAppliedPolicies := resourcePolicies[resource.ID]
var routes []*route.Route
// distribute the resource routes only if there is policy applied to it
if len(resourceAppliedPolicies) > 0 {
peer := a.GetPeer(peerId)
if peer != nil {
routes = append(routes, resource.ToRoute(peer, router))
}
}
return routes
}
func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter {
routers := make(map[string]map[string]*routerTypes.NetworkRouter)

View File

@@ -4,8 +4,6 @@ import (
"context"
"fmt"
"net"
"net/netip"
"slices"
"testing"
"github.com/miekg/dns"
@@ -19,7 +17,6 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route"
)
@@ -451,402 +448,6 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) {
require.Len(t, result, 0)
}
const (
accID = "accountID"
network1ID = "network1ID"
group1ID = "group1"
accNetResourcePeer1ID = "peer1"
accNetResourcePeer2ID = "peer2"
accNetResourceRouter1ID = "router1"
accNetResource1ID = "resource1ID"
accNetResourceRestrictPostureCheckID = "restrictPostureCheck"
accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck"
accNetResourceLockedPostureCheckID = "lockedPostureCheck"
accNetResourceLinuxPostureCheckID = "linuxPostureCheck"
)
var (
accNetResourcePeer1IP = net.IP{192, 168, 1, 1}
accNetResourcePeer2IP = net.IP{192, 168, 1, 2}
accNetResourceRouter1IP = net.IP{192, 168, 1, 3}
accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}}
)
func getBasicAccountsWithResource() *Account {
return &Account{
Id: accID,
Peers: map[string]*nbpeer.Peer{
accNetResourcePeer1ID: {
ID: accNetResourcePeer1ID,
AccountID: accID,
Key: "peer1Key",
IP: accNetResourcePeer1IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
WtVersion: "0.35.1",
KernelVersion: "4.4.0",
},
},
accNetResourcePeer2ID: {
ID: accNetResourcePeer2ID,
AccountID: accID,
Key: "peer2Key",
IP: accNetResourcePeer2IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "windows",
WtVersion: "0.34.1",
KernelVersion: "4.4.0",
},
},
accNetResourceRouter1ID: {
ID: accNetResourceRouter1ID,
AccountID: accID,
Key: "router1Key",
IP: accNetResourceRouter1IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
WtVersion: "0.35.1",
KernelVersion: "4.4.0",
},
},
},
Groups: map[string]*Group{
group1ID: {
ID: group1ID,
Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID},
},
},
Networks: []*networkTypes.Network{
{
ID: network1ID,
AccountID: accID,
Name: "network1",
},
},
NetworkRouters: []*routerTypes.NetworkRouter{
{
ID: accNetResourceRouter1ID,
NetworkID: network1ID,
AccountID: accID,
Peer: accNetResourceRouter1ID,
PeerGroups: []string{},
Masquerade: false,
Metric: 100,
Enabled: true,
},
},
NetworkResources: []*resourceTypes.NetworkResource{
{
ID: accNetResource1ID,
AccountID: accID,
NetworkID: network1ID,
Address: "10.10.10.0/24",
Prefix: netip.MustParsePrefix("10.10.10.0/24"),
Type: resourceTypes.NetworkResourceType("subnet"),
Enabled: true,
},
},
Policies: []*Policy{
{
ID: "policy1ID",
AccountID: accID,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "rule1ID",
Enabled: true,
Sources: []string{group1ID},
DestinationResource: Resource{
ID: accNetResource1ID,
Type: "Host",
},
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"80"},
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: nil,
},
},
PostureChecks: []*posture.Checks{
{
ID: accNetResourceRestrictPostureCheckID,
Name: accNetResourceRestrictPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.35.0",
},
},
},
{
ID: accNetResourceRelaxedPostureCheckID,
Name: accNetResourceRelaxedPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1",
},
},
},
{
ID: accNetResourceLockedPostureCheckID,
Name: accNetResourceLockedPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "7.7.7",
},
},
},
{
ID: accNetResourceLinuxPostureCheckID,
Name: accNetResourceLinuxPostureCheckID,
Checks: posture.ChecksDefinition{
OSVersionCheck: &posture.OSVersionCheck{
Linux: &posture.MinKernelVersionCheck{
MinKernelVersion: "0.0.0"},
},
},
},
},
}
}
func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// all peers should match the policy
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should allow peer1 to match the policy
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should not match any peer
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 0, "expected rules count don't match")
}
func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should allow peer1 to match the policy
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
// should allow peer1 and peer2 to match the policy
newPolicy := &Policy{
ID: "policy2ID",
AccountID: accID,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "policy2ID",
Enabled: true,
Sources: []string{group1ID},
DestinationResource: Resource{
ID: accNetResource1ID,
Type: "Host",
},
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"22"},
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID},
}
account.Policies = append(account.Policies, newPolicy)
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 2, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
assert.Equal(t, uint16(22), rules[1].Port, "should have port 22")
assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp")
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String())
}
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// two posture checks should match only the peers that match both checks
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) {
account := getBasicAccountsWithResource()
account.Peers["router2Id"] = &nbpeer.Peer{Key: "router2Key", ID: "router2Id", AccountID: accID, IP: net.IP{192, 168, 1, 4}}
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
ID: "router2Id",
NetworkID: network1ID,
AccountID: accID,
Peer: "router2Id",
})
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
}
func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
tests := []struct {
name string

View File

@@ -1,47 +0,0 @@
package types
import (
"context"
"sync"
)
type Holder struct {
mu sync.RWMutex
accounts map[string]*Account
}
func NewHolder() *Holder {
return &Holder{
accounts: make(map[string]*Account),
}
}
func (h *Holder) GetAccount(id string) *Account {
h.mu.RLock()
defer h.mu.RUnlock()
return h.accounts[id]
}
func (h *Holder) AddAccount(account *Account) {
h.mu.Lock()
defer h.mu.Unlock()
a := h.accounts[account.Id]
if a != nil && a.Network.CurrentSerial() >= account.Network.CurrentSerial() {
return
}
h.accounts[account.Id] = account
}
func (h *Holder) LoadOrStoreFunc(ctx context.Context, id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
h.mu.Lock()
defer h.mu.Unlock()
if acc, ok := h.accounts[id]; ok {
return acc, nil
}
account, err := accGetter(ctx, id)
if err != nil {
return nil, err
}
h.accounts[id] = account
return account, nil
}

View File

@@ -1,67 +0,0 @@
package types
import (
"context"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
)
func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) {
if a.NetworkMapCache != nil {
return
}
a.nmapInitOnce.Do(func() {
a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers)
})
}
func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) {
a.initNetworkMapBuilder(validatedPeers)
}
func (a *Account) GetPeerNetworkMapExp(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
validatedPeers map[string]struct{},
metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
a.initNetworkMapBuilder(validatedPeers)
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, accountZones, validatedPeers, metrics)
}
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
if a.NetworkMapCache == nil {
return nil
}
return a.NetworkMapCache.OnPeerAddedIncremental(a, peerId)
}
func (a *Account) OnPeersAddedUpdNetworkMapCache(peerIds ...string) {
if a.NetworkMapCache == nil {
return
}
a.NetworkMapCache.EnqueuePeersForIncrementalAdd(a, peerIds...)
}
func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error {
if a.NetworkMapCache == nil {
return nil
}
return a.NetworkMapCache.OnPeerDeleted(a, peerId)
}
func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) {
if a.NetworkMapCache == nil {
return
}
a.NetworkMapCache.UpdatePeer(peer)
}
func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) {
a.initNetworkMapBuilder(validatedPeers)
}

View File

@@ -1,592 +0,0 @@
package types
import (
"context"
"encoding/json"
"fmt"
"net"
"net/netip"
"os"
"path/filepath"
"sort"
"testing"
"time"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route"
)
func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid == offlinePeerID {
continue
}
validatedPeersMap[pid] = struct{}{}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
legacyNetworkMap := account.GetPeerNetworkMap(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
if components == nil {
t.Fatal("GetPeerNetworkMapComponents returned nil")
}
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
if newNetworkMap == nil {
t.Fatal("CalculateNetworkMapFromComponents returned nil")
}
compareNetworkMaps(t, legacyNetworkMap, newNetworkMap)
}
func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid == offlinePeerID {
continue
}
validatedPeersMap[pid] = struct{}{}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
legacyNetworkMap := account.GetPeerNetworkMap(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil")
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil")
normalizeAndSortNetworkMap(legacyNetworkMap)
normalizeAndSortNetworkMap(newNetworkMap)
componentsJSON, err := json.MarshalIndent(components, "", " ")
require.NoError(t, err, "error marshaling components to JSON")
legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
newJSON, err := json.MarshalIndent(newNetworkMap, "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
goldenDir := filepath.Join("testdata", "comparison")
err = os.MkdirAll(goldenDir, 0755)
require.NoError(t, err)
legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json")
err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644)
require.NoError(t, err, "error writing legacy golden file")
newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json")
err = os.WriteFile(newGoldenPath, newJSON, 0644)
require.NoError(t, err, "error writing components golden file")
componentsPath := filepath.Join(goldenDir, "components.json")
err = os.WriteFile(componentsPath, componentsJSON, 0644)
require.NoError(t, err, "error writing components golden file")
require.JSONEq(t, string(legacyJSON), string(newJSON),
"NetworkMaps from legacy and components approaches do not match.\n"+
"Legacy JSON saved to: %s\n"+
"Components JSON saved to: %s",
legacyGoldenPath, newGoldenPath)
t.Logf("✅ NetworkMaps are identical")
t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath)
t.Logf(" Components NetworkMap: %s", newGoldenPath)
}
func normalizeAndSortNetworkMap(nm *NetworkMap) {
if nm == nil {
return
}
sort.Slice(nm.Peers, func(i, j int) bool {
return nm.Peers[i].ID < nm.Peers[j].ID
})
sort.Slice(nm.OfflinePeers, func(i, j int) bool {
return nm.OfflinePeers[i].ID < nm.OfflinePeers[j].ID
})
sort.Slice(nm.Routes, func(i, j int) bool {
return string(nm.Routes[i].ID) < string(nm.Routes[j].ID)
})
sort.Slice(nm.FirewallRules, func(i, j int) bool {
if nm.FirewallRules[i].PeerIP != nm.FirewallRules[j].PeerIP {
return nm.FirewallRules[i].PeerIP < nm.FirewallRules[j].PeerIP
}
if nm.FirewallRules[i].Direction != nm.FirewallRules[j].Direction {
return nm.FirewallRules[i].Direction < nm.FirewallRules[j].Direction
}
if nm.FirewallRules[i].Protocol != nm.FirewallRules[j].Protocol {
return nm.FirewallRules[i].Protocol < nm.FirewallRules[j].Protocol
}
if nm.FirewallRules[i].Port != nm.FirewallRules[j].Port {
return nm.FirewallRules[i].Port < nm.FirewallRules[j].Port
}
return nm.FirewallRules[i].PolicyID < nm.FirewallRules[j].PolicyID
})
for i := range nm.RoutesFirewallRules {
sort.Strings(nm.RoutesFirewallRules[i].SourceRanges)
}
sort.Slice(nm.RoutesFirewallRules, func(i, j int) bool {
if nm.RoutesFirewallRules[i].Destination != nm.RoutesFirewallRules[j].Destination {
return nm.RoutesFirewallRules[i].Destination < nm.RoutesFirewallRules[j].Destination
}
minLen := len(nm.RoutesFirewallRules[i].SourceRanges)
if len(nm.RoutesFirewallRules[j].SourceRanges) < minLen {
minLen = len(nm.RoutesFirewallRules[j].SourceRanges)
}
for k := 0; k < minLen; k++ {
if nm.RoutesFirewallRules[i].SourceRanges[k] != nm.RoutesFirewallRules[j].SourceRanges[k] {
return nm.RoutesFirewallRules[i].SourceRanges[k] < nm.RoutesFirewallRules[j].SourceRanges[k]
}
}
if len(nm.RoutesFirewallRules[i].SourceRanges) != len(nm.RoutesFirewallRules[j].SourceRanges) {
return len(nm.RoutesFirewallRules[i].SourceRanges) < len(nm.RoutesFirewallRules[j].SourceRanges)
}
if string(nm.RoutesFirewallRules[i].RouteID) != string(nm.RoutesFirewallRules[j].RouteID) {
return string(nm.RoutesFirewallRules[i].RouteID) < string(nm.RoutesFirewallRules[j].RouteID)
}
if nm.RoutesFirewallRules[i].PolicyID != nm.RoutesFirewallRules[j].PolicyID {
return nm.RoutesFirewallRules[i].PolicyID < nm.RoutesFirewallRules[j].PolicyID
}
if nm.RoutesFirewallRules[i].Port != nm.RoutesFirewallRules[j].Port {
return nm.RoutesFirewallRules[i].Port < nm.RoutesFirewallRules[j].Port
}
return nm.RoutesFirewallRules[i].Protocol < nm.RoutesFirewallRules[j].Protocol
})
if nm.DNSConfig.CustomZones != nil {
for i := range nm.DNSConfig.CustomZones {
sort.Slice(nm.DNSConfig.CustomZones[i].Records, func(a, b int) bool {
return nm.DNSConfig.CustomZones[i].Records[a].Name < nm.DNSConfig.CustomZones[i].Records[b].Name
})
}
}
if len(nm.DNSConfig.NameServerGroups) != 0 {
sort.Slice(nm.DNSConfig.NameServerGroups, func(a, b int) bool {
return nm.DNSConfig.NameServerGroups[a].Name < nm.DNSConfig.NameServerGroups[b].Name
})
}
}
func compareNetworkMaps(t *testing.T, legacy, current *NetworkMap) {
t.Helper()
if legacy.Network.Serial != current.Network.Serial {
t.Errorf("Network Serial mismatch: legacy=%d, current=%d", legacy.Network.Serial, current.Network.Serial)
}
if len(legacy.Peers) != len(current.Peers) {
t.Errorf("Peers count mismatch: legacy=%d, current=%d", len(legacy.Peers), len(current.Peers))
}
legacyPeerIDs := make(map[string]bool)
for _, p := range legacy.Peers {
legacyPeerIDs[p.ID] = true
}
for _, p := range current.Peers {
if !legacyPeerIDs[p.ID] {
t.Errorf("Current NetworkMap contains peer %s not in legacy", p.ID)
}
}
if len(legacy.OfflinePeers) != len(current.OfflinePeers) {
t.Errorf("OfflinePeers count mismatch: legacy=%d, current=%d", len(legacy.OfflinePeers), len(current.OfflinePeers))
}
if len(legacy.FirewallRules) != len(current.FirewallRules) {
t.Logf("FirewallRules count mismatch: legacy=%d, current=%d", len(legacy.FirewallRules), len(current.FirewallRules))
}
if len(legacy.Routes) != len(current.Routes) {
t.Logf("Routes count mismatch: legacy=%d, current=%d", len(legacy.Routes), len(current.Routes))
}
if len(legacy.RoutesFirewallRules) != len(current.RoutesFirewallRules) {
t.Logf("RoutesFirewallRules count mismatch: legacy=%d, current=%d", len(legacy.RoutesFirewallRules), len(current.RoutesFirewallRules))
}
if legacy.DNSConfig.ServiceEnable != current.DNSConfig.ServiceEnable {
t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, current=%v", legacy.DNSConfig.ServiceEnable, current.DNSConfig.ServiceEnable)
}
}
const (
numPeers = 100
devGroupID = "group-dev"
opsGroupID = "group-ops"
allGroupID = "group-all"
routeID = route.ID("route-main")
routeHA1ID = route.ID("route-ha-1")
routeHA2ID = route.ID("route-ha-2")
policyIDDevOps = "policy-dev-ops"
policyIDAll = "policy-all"
policyIDPosture = "policy-posture"
policyIDDrop = "policy-drop"
postureCheckID = "posture-check-ver"
networkResourceID = "res-database"
networkID = "net-database"
networkRouterID = "router-database"
nameserverGroupID = "ns-group-main"
testingPeerID = "peer-60"
expiredPeerID = "peer-98"
offlinePeerID = "peer-99"
routingPeerID = "peer-95"
testAccountID = "account-comparison-test"
)
func createTestAccount() *Account {
peers := make(map[string]*nbpeer.Peer)
devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{}
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
ip := net.IP{100, 64, 0, byte(i + 1)}
wtVersion := "0.25.0"
if i%2 == 0 {
wtVersion = "0.40.0"
}
p := &nbpeer.Peer{
ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1),
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"},
}
if peerID == expiredPeerID {
p.LoginExpirationEnabled = true
pastTimestamp := time.Now().Add(-2 * time.Hour)
p.LastLogin = &pastTimestamp
}
peers[peerID] = p
allGroupPeers = append(allGroupPeers, peerID)
if i < numPeers/2 {
devGroupPeers = append(devGroupPeers, peerID)
} else {
opsGroupPeers = append(opsGroupPeers, peerID)
}
}
groups := map[string]*Group{
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
}
policies := []*Policy{
{
ID: policyIDAll, Name: "Default-Allow", Enabled: true,
Rules: []*PolicyRule{{
ID: policyIDAll, Name: "Allow All", Enabled: true, Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
Sources: []string{allGroupID}, Destinations: []string{allGroupID},
}},
},
{
ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true,
Rules: []*PolicyRule{{
ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolTCP, Bidirectional: false,
PortRanges: []RulePortRange{{Start: 8080, End: 8090}},
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
}},
},
{
ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true,
Rules: []*PolicyRule{{
ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: PolicyTrafficActionDrop,
Protocol: PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true,
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
}},
},
{
ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true,
SourcePostureChecks: []string{postureCheckID},
Rules: []*PolicyRule{{
ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
Sources: []string{opsGroupID}, DestinationResource: Resource{ID: networkResourceID},
}},
},
}
routes := map[route.ID]*route.Route{
routeID: {
ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"),
Peer: peers["peer-75"].Key,
PeerID: "peer-75",
Description: "Route to internal resource", Enabled: true,
PeerGroups: []string{devGroupID, opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{devGroupID},
},
routeHA1ID: {
ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
Peer: peers["peer-80"].Key,
PeerID: "peer-80",
Description: "HA Route 1", Enabled: true, Metric: 1000,
PeerGroups: []string{allGroupID},
Groups: []string{allGroupID},
AccessControlGroups: []string{allGroupID},
},
routeHA2ID: {
ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
Peer: peers["peer-90"].Key,
PeerID: "peer-90",
Description: "HA Route 2", Enabled: true, Metric: 900,
PeerGroups: []string{devGroupID, opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{allGroupID},
},
}
account := &Account{
Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes,
Network: &Network{
Identifier: "net-comparison-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
},
DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}},
NameServerGroups: map[string]*nbdns.NameServerGroup{
nameserverGroupID: {
ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID},
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}},
},
},
PostureChecks: []*posture.Checks{
{ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
}},
},
NetworkResources: []*resourceTypes.NetworkResource{
{ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"},
},
Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}},
NetworkRouters: []*routerTypes.NetworkRouter{
{ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID},
},
Settings: &Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
}
for _, p := range account.Policies {
p.AccountID = account.Id
}
for _, r := range account.Routes {
r.AccountID = account.Id
}
return account
}
func BenchmarkLegacyNetworkMap(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMap(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
}
}
func BenchmarkComponentsNetworkMap(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ResetTimer()
for i := 0; i < b.N; i++ {
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
_ = CalculateNetworkMapFromComponents(ctx, components)
}
}
func BenchmarkComponentsCreation(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
}
}
func BenchmarkCalculationFromComponents(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = CalculateNetworkMapFromComponents(ctx, components)
}
}

View File

@@ -19,8 +19,6 @@ import (
"github.com/netbirdio/netbird/shared/management/domain"
)
const EnvNewNetworkMapCompacted = "NB_NETWORK_MAP_COMPACTED"
type NetworkMapComponents struct {
PeerID string

View File

@@ -0,0 +1,787 @@
package types_test
import (
"context"
"net"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
func networkMapFromComponents(t *testing.T, account *types.Account, peerID string, validatedPeers map[string]struct{}) *types.NetworkMap {
t.Helper()
return account.GetPeerNetworkMapFromComponents(
context.Background(),
peerID,
account.GetPeersCustomZone(context.Background(), "netbird.io"),
nil,
validatedPeers,
account.GetResourcePoliciesMap(),
account.GetResourceRoutersMap(),
nil,
account.GetActiveGroupUsers(),
)
}
func allPeersValidated(account *types.Account, excludePeerIDs ...string) map[string]struct{} {
excludeSet := make(map[string]struct{}, len(excludePeerIDs))
for _, id := range excludePeerIDs {
excludeSet[id] = struct{}{}
}
validated := make(map[string]struct{}, len(account.Peers))
for id := range account.Peers {
if _, excluded := excludeSet[id]; !excluded {
validated[id] = struct{}{}
}
}
return validated
}
func peerIDs(peers []*nbpeer.Peer) []string {
ids := make([]string, len(peers))
for i, p := range peers {
ids[i] = p.ID
}
return ids
}
func TestNetworkMapComponents_RegularPeerConnectivity(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.NotNil(t, nm)
assert.Contains(t, peerIDs(nm.Peers), "peer-dst-1", "should see peer from destination group via bidirectional policy")
assert.Contains(t, peerIDs(nm.Peers), "peer-router-1", "should see router peer via resource policy")
assert.NotContains(t, peerIDs(nm.Peers), "peer-src-1", "should not see itself")
assert.Empty(t, nm.OfflinePeers, "no expired peers expected")
}
func TestNetworkMapComponents_IntraGroupConnectivity(t *testing.T) {
account := createComponentTestAccount()
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-intra-src", Name: "Intra-source connectivity", Enabled: true, AccountID: account.Id,
Rules: []*types.PolicyRule{{
ID: "rule-intra-src", Name: "src <-> src", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Bidirectional: true,
Sources: []string{"group-src"}, Destinations: []string{"group-src"},
}},
})
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.Contains(t, peerIDs(nm.Peers), "peer-src-2", "should see peer from same group with intra-group policy")
}
func TestNetworkMapComponents_FirewallRules(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
require.NotEmpty(t, nm.FirewallRules, "firewall rules should be generated")
var hasAcceptAll bool
for _, rule := range nm.FirewallRules {
if rule.Protocol == string(types.PolicyRuleProtocolALL) && rule.Action == string(types.PolicyTrafficActionAccept) {
hasAcceptAll = true
}
}
assert.True(t, hasAcceptAll, "should have an accept-all firewall rule from the base policy")
}
func TestNetworkMapComponents_LoginExpiration(t *testing.T) {
account := createComponentTestAccount()
account.Settings.PeerLoginExpirationEnabled = true
account.Settings.PeerLoginExpiration = 1 * time.Hour
expiredTime := time.Now().Add(-2 * time.Hour)
account.Peers["peer-dst-1"].LoginExpirationEnabled = true
account.Peers["peer-dst-1"].LastLogin = &expiredTime
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.Contains(t, peerIDs(nm.OfflinePeers), "peer-dst-1", "expired peer should be in OfflinePeers")
assert.NotContains(t, peerIDs(nm.Peers), "peer-dst-1", "expired peer should NOT be in active Peers")
}
func TestNetworkMapComponents_InvalidatedPeerExcluded(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account, "peer-dst-1")
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.NotContains(t, peerIDs(nm.Peers), "peer-dst-1", "non-validated peer should be excluded")
assert.NotContains(t, peerIDs(nm.OfflinePeers), "peer-dst-1", "non-validated peer should not be in offline peers either")
}
func TestNetworkMapComponents_NonValidatedTargetPeer(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account, "peer-src-1")
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.Empty(t, nm.Peers, "non-validated target peer should get empty network map")
assert.Empty(t, nm.FirewallRules)
}
func TestNetworkMapComponents_NetworkResourceRoutes_SourcePeer(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var hasResourceRoute bool
for _, r := range nm.Routes {
if r.Network.String() == "10.200.0.1/32" {
hasResourceRoute = true
break
}
}
assert.True(t, hasResourceRoute, "source peer should receive route to network resource via router")
assert.Contains(t, peerIDs(nm.Peers), "peer-router-1", "source peer should see the routing peer")
}
func TestNetworkMapComponents_NetworkResourceRoutes_RouterPeer(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
var hasResourceRoute bool
for _, r := range nm.Routes {
if r.Network.String() == "10.200.0.1/32" {
hasResourceRoute = true
break
}
}
assert.True(t, hasResourceRoute, "router peer should receive network resource route")
assert.NotEmpty(t, nm.RoutesFirewallRules, "router peer should have route firewall rules for the resource")
}
func TestNetworkMapComponents_NetworkResourceRoutes_UnrelatedPeer(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-dst-1", validated)
for _, r := range nm.Routes {
assert.NotEqual(t, "10.200.0.1/32", r.Network.String(), "unrelated peer should not receive network resource route")
}
}
func TestNetworkMapComponents_NetworkResource_WithPostureCheck(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.PostureChecks = []*posture.Checks{
{ID: "pc-version", Name: "Version check", Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
}},
}
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-posture-resource", Name: "Posture resource access", Enabled: true, AccountID: account.Id,
SourcePostureChecks: []string{"pc-version"},
Rules: []*types.PolicyRule{{
ID: "rule-posture-resource", Name: "Posture -> Resource", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Sources: []string{"group-src"},
DestinationResource: types.Resource{ID: "resource-guarded"},
}},
})
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
ID: "resource-guarded", NetworkID: "net-guarded", AccountID: account.Id, Enabled: true,
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.1.1/32"), Address: "10.200.1.1/32",
})
account.Networks = append(account.Networks, &networkTypes.Network{
ID: "net-guarded", Name: "Guarded Net", AccountID: account.Id,
})
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
ID: "router-guarded", NetworkID: "net-guarded", Peer: "peer-router-1", Enabled: true, AccountID: account.Id,
})
t.Run("peer passes posture check", func(t *testing.T) {
account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0"
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var hasGuardedRoute bool
for _, r := range nm.Routes {
if r.Network.String() == "10.200.1.1/32" {
hasGuardedRoute = true
}
}
assert.True(t, hasGuardedRoute, "peer passing posture check should get guarded resource route")
})
t.Run("peer fails posture check", func(t *testing.T) {
account.Peers["peer-src-1"].Meta.WtVersion = "0.20.0"
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
for _, r := range nm.Routes {
assert.NotEqual(t, "10.200.1.1/32", r.Network.String(), "peer failing posture check should NOT get guarded resource route")
}
})
}
func TestNetworkMapComponents_NetworkResource_MultiplePostureChecks(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.PostureChecks = []*posture.Checks{
{ID: "pc-version", Name: "Version", Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
}},
{ID: "pc-os", Name: "OS check", Checks: posture.ChecksDefinition{
OSVersionCheck: &posture.OSVersionCheck{Linux: &posture.MinKernelVersionCheck{MinKernelVersion: "5.0"}},
}},
}
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-multi-posture", Name: "Multi posture", Enabled: true, AccountID: account.Id,
SourcePostureChecks: []string{"pc-version", "pc-os"},
Rules: []*types.PolicyRule{{
ID: "rule-multi-posture", Name: "Multi posture rule", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Sources: []string{"group-src"},
DestinationResource: types.Resource{ID: "resource-strict"},
}},
})
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
ID: "resource-strict", NetworkID: "net-strict", AccountID: account.Id, Enabled: true,
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.2.1/32"), Address: "10.200.2.1/32",
})
account.Networks = append(account.Networks, &networkTypes.Network{
ID: "net-strict", Name: "Strict Net", AccountID: account.Id,
})
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
ID: "router-strict", NetworkID: "net-strict", Peer: "peer-router-1", Enabled: true, AccountID: account.Id,
})
t.Run("passes both posture checks", func(t *testing.T) {
account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0"
account.Peers["peer-src-1"].Meta.GoOS = "linux"
account.Peers["peer-src-1"].Meta.KernelVersion = "6.1.0"
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var found bool
for _, r := range nm.Routes {
if r.Network.String() == "10.200.2.1/32" {
found = true
}
}
assert.True(t, found, "peer passing both checks should get resource route")
})
t.Run("fails version posture check", func(t *testing.T) {
account.Peers["peer-src-1"].Meta.WtVersion = "0.20.0"
account.Peers["peer-src-1"].Meta.KernelVersion = "6.1.0"
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
for _, r := range nm.Routes {
assert.NotEqual(t, "10.200.2.1/32", r.Network.String(), "peer failing version check should NOT get resource route")
}
})
t.Run("fails OS posture check", func(t *testing.T) {
account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0"
account.Peers["peer-src-1"].Meta.KernelVersion = "4.0.0"
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
for _, r := range nm.Routes {
assert.NotEqual(t, "10.200.2.1/32", r.Network.String(), "peer failing OS check should NOT get resource route")
}
})
}
func TestNetworkMapComponents_RouterPeerFirewallRules(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
var resourceFWRules []*types.RouteFirewallRule
for _, rule := range nm.RoutesFirewallRules {
if rule.Destination == "10.200.0.1/32" {
resourceFWRules = append(resourceFWRules, rule)
}
}
assert.NotEmpty(t, resourceFWRules, "router should have firewall rules for the network resource")
var hasSourcePeerIP bool
for _, rule := range resourceFWRules {
for _, sr := range rule.SourceRanges {
if sr == account.Peers["peer-src-1"].IP.String()+"/32" || sr == account.Peers["peer-src-2"].IP.String()+"/32" {
hasSourcePeerIP = true
}
}
}
assert.True(t, hasSourcePeerIP, "resource firewall rules should include source peer IPs")
}
func TestNetworkMapComponents_DNSManagement(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
t.Run("peer in DNS-enabled group", func(t *testing.T) {
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.True(t, nm.DNSConfig.ServiceEnable, "peer in non-disabled group should have DNS enabled")
})
t.Run("peer in DNS-disabled group", func(t *testing.T) {
nm := networkMapFromComponents(t, account, "peer-dst-1", validated)
assert.False(t, nm.DNSConfig.ServiceEnable, "peer in DNS-disabled group should have DNS disabled")
})
}
func TestNetworkMapComponents_NameServerGroups(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.True(t, nm.DNSConfig.ServiceEnable)
var hasNSGroup bool
for _, ns := range nm.DNSConfig.NameServerGroups {
if ns.ID == "ns-main" {
hasNSGroup = true
}
}
assert.True(t, hasNSGroup, "peer in NS group should receive nameserver configuration")
}
func TestNetworkMapComponents_RoutesWithHADeduplication(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.Routes["route-ha-1"] = &route.Route{
ID: "route-ha-1", Network: netip.MustParsePrefix("172.16.0.0/16"),
Peer: account.Peers["peer-dst-1"].Key, PeerID: "peer-dst-1",
Enabled: true, Metric: 100, AccountID: account.Id,
Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-dst"},
}
account.Routes["route-ha-2"] = &route.Route{
ID: "route-ha-2", Network: netip.MustParsePrefix("172.16.0.0/16"),
Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1",
Enabled: true, Metric: 200, AccountID: account.Id,
Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-src"},
}
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
haCount := 0
for _, r := range nm.Routes {
if r.Network.String() == "172.16.0.0/16" {
haCount++
}
}
assert.Equal(t, 1, haCount, "peer should only receive one route from HA group (not both, since it's a member of one)")
}
func TestNetworkMapComponents_RoutesFirewallRulesForAccessControl(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.Routes["route-acl"] = &route.Route{
ID: "route-acl", Network: netip.MustParsePrefix("192.168.100.0/24"),
Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1",
Enabled: true, Metric: 100, AccountID: account.Id,
Groups: []string{"group-dst"},
PeerGroups: []string{"group-src"},
AccessControlGroups: []string{"group-dst"},
}
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var hasFWRule bool
for _, rule := range nm.RoutesFirewallRules {
if rule.Destination == "192.168.100.0/24" {
hasFWRule = true
}
}
assert.True(t, hasFWRule, "routing peer should have firewall rules for route with access control groups")
}
func TestNetworkMapComponents_RoutesDefaultPermit(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.Routes["route-open"] = &route.Route{
ID: "route-open", Network: netip.MustParsePrefix("10.99.0.0/16"),
Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1",
Enabled: true, Metric: 100, AccountID: account.Id,
Groups: []string{"group-src"},
PeerGroups: []string{"group-src"},
}
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var hasFWRule bool
for _, rule := range nm.RoutesFirewallRules {
if rule.Destination == "10.99.0.0/16" {
hasFWRule = true
}
}
assert.True(t, hasFWRule, "route without access control groups should have default permit firewall rules")
}
func TestNetworkMapComponents_SSHAuthorizedUsers(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.Peers["peer-dst-1"].SSHEnabled = true
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: account.Id,
Rules: []*types.PolicyRule{{
ID: "rule-ssh", Name: "SSH to dst", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Bidirectional: true,
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
}},
})
nm := networkMapFromComponents(t, account, "peer-dst-1", validated)
assert.True(t, nm.EnableSSH, "SSH-enabled peer with matching policy should have EnableSSH")
}
func TestNetworkMapComponents_DisabledPolicyIgnored(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
for _, p := range account.Policies {
p.Enabled = false
}
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.Empty(t, nm.Peers, "with all policies disabled, peer should see no other peers")
assert.Empty(t, nm.FirewallRules)
}
func TestNetworkMapComponents_DisabledRouteIgnored(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
for _, r := range account.Routes {
r.Enabled = false
}
for _, r := range account.NetworkResources {
r.Enabled = false
}
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.Empty(t, nm.Routes, "disabled routes should not appear in network map")
}
func TestNetworkMapComponents_DisabledNetworkResourceIgnored(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
for _, r := range account.NetworkResources {
r.Enabled = false
}
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
for _, r := range nm.Routes {
assert.NotEqual(t, "10.200.0.1/32", r.Network.String(), "disabled resource should not generate routes")
}
}
func TestNetworkMapComponents_BidirectionalPolicy(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nmSrc := networkMapFromComponents(t, account, "peer-src-1", validated)
nmDst := networkMapFromComponents(t, account, "peer-dst-1", validated)
assert.Contains(t, peerIDs(nmSrc.Peers), "peer-dst-1", "src should see dst via bidirectional policy")
assert.Contains(t, peerIDs(nmDst.Peers), "peer-src-1", "dst should see src via bidirectional policy")
}
func TestNetworkMapComponents_DropPolicy(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-drop", Name: "Drop traffic", Enabled: true, AccountID: account.Id,
Rules: []*types.PolicyRule{{
ID: "rule-drop", Name: "Drop src->dst", Enabled: true,
Action: types.PolicyTrafficActionDrop, Protocol: types.PolicyRuleProtocolTCP,
Ports: []string{"5432"},
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
}},
})
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var hasDropRule bool
for _, rule := range nm.FirewallRules {
if rule.Action == string(types.PolicyTrafficActionDrop) && rule.Port == "5432" {
hasDropRule = true
}
}
assert.True(t, hasDropRule, "drop policy should generate drop firewall rule")
}
func TestNetworkMapComponents_PortRangePolicy(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.Peers["peer-src-1"].Meta.WtVersion = "0.50.0"
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-range", Name: "Port range", Enabled: true, AccountID: account.Id,
Rules: []*types.PolicyRule{{
ID: "rule-range", Name: "Range rule", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP,
PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}},
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
}},
})
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var hasRangeRule bool
for _, rule := range nm.FirewallRules {
if rule.PortRange.Start == 8080 && rule.PortRange.End == 8090 {
hasRangeRule = true
}
}
assert.True(t, hasRangeRule, "port range policy should generate corresponding firewall rule")
}
func TestNetworkMapComponents_MultipleNetworkResources(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
ID: "resource-2", NetworkID: "net-1", AccountID: account.Id, Enabled: true,
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.0.2/32"), Address: "10.200.0.2/32",
})
account.Groups["group-res2"] = &types.Group{ID: "group-res2", Name: "Resource 2 Group", Peers: []string{"peer-src-1", "peer-src-2"},
Resources: []types.Resource{{ID: "resource-2"}},
}
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-res2", Name: "Resource 2 Policy", Enabled: true, AccountID: account.Id,
Rules: []*types.PolicyRule{{
ID: "rule-res2", Name: "Access Resource 2", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Sources: []string{"group-src"},
DestinationResource: types.Resource{ID: "resource-2"},
}},
})
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
resourceRouteCount := 0
for _, r := range nm.Routes {
if r.Network.String() == "10.200.0.1/32" || r.Network.String() == "10.200.0.2/32" {
resourceRouteCount++
}
}
assert.Equal(t, 2, resourceRouteCount, "router should have routes for both network resources")
}
func TestNetworkMapComponents_DomainNetworkResource(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
ID: "resource-domain", NetworkID: "net-1", AccountID: account.Id, Enabled: true,
Type: resourceTypes.Domain, Domain: "api.example.com", Address: "api.example.com",
})
account.Groups["group-res-domain"] = &types.Group{
ID: "group-res-domain", Name: "Domain Resource Group",
Resources: []types.Resource{{ID: "resource-domain"}},
}
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-domain", Name: "Domain resource policy", Enabled: true, AccountID: account.Id,
Rules: []*types.PolicyRule{{
ID: "rule-domain", Name: "Access domain resource", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Sources: []string{"group-src"},
DestinationResource: types.Resource{ID: "resource-domain"},
}},
})
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var hasDomainRoute bool
for _, r := range nm.Routes {
if r.NetworkType == route.DomainNetwork && len(r.Domains) > 0 && r.Domains[0].SafeString() == "api.example.com" {
hasDomainRoute = true
}
}
assert.True(t, hasDomainRoute, "source peer should receive domain route for domain network resource")
}
func TestNetworkMapComponents_NetworkEmpty(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "nonexistent-peer", validated)
assert.NotNil(t, nm)
assert.Empty(t, nm.Peers)
assert.Empty(t, nm.FirewallRules)
assert.NotNil(t, nm.Network)
}
func TestNetworkMapComponents_RouterExcludesOtherNetworkRoutes(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
ID: "resource-other", NetworkID: "net-other", AccountID: account.Id, Enabled: true,
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.99.1/32"), Address: "10.200.99.1/32",
})
account.Networks = append(account.Networks, &networkTypes.Network{
ID: "net-other", Name: "Other Net", AccountID: account.Id,
})
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
ID: "router-other", NetworkID: "net-other", Peer: "peer-dst-1", Enabled: true, AccountID: account.Id,
})
account.Groups["group-res-other"] = &types.Group{ID: "group-res-other", Name: "Other resource group",
Resources: []types.Resource{{ID: "resource-other"}},
}
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-other-resource", Name: "Other resource policy", Enabled: true, AccountID: account.Id,
Rules: []*types.PolicyRule{{
ID: "rule-other", Name: "Other resource access", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Sources: []string{"group-src"},
DestinationResource: types.Resource{ID: "resource-other"},
}},
})
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
for _, r := range nm.Routes {
assert.NotEqual(t, "10.200.99.1/32", r.Network.String(), "router-1 should NOT get routes for other network's resources")
}
}
func createComponentTestAccount() *types.Account {
peers := map[string]*nbpeer.Peer{
"peer-src-1": {
ID: "peer-src-1", IP: net.IP{100, 64, 0, 1}, Key: "key-src-1", DNSLabel: "src1",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
},
"peer-src-2": {
ID: "peer-src-2", IP: net.IP{100, 64, 0, 2}, Key: "key-src-2", DNSLabel: "src2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
},
"peer-dst-1": {
ID: "peer-dst-1", IP: net.IP{100, 64, 0, 3}, Key: "key-dst-1", DNSLabel: "dst1",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-2",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
},
"peer-router-1": {
ID: "peer-router-1", IP: net.IP{100, 64, 0, 10}, Key: "key-router-1", DNSLabel: "router1",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
},
}
groups := map[string]*types.Group{
"group-src": {ID: "group-src", Name: "Sources", Peers: []string{"peer-src-1", "peer-src-2"}},
"group-dst": {ID: "group-dst", Name: "Destinations", Peers: []string{"peer-dst-1"}},
"group-all": {ID: "group-all", Name: "All", Peers: []string{"peer-src-1", "peer-src-2", "peer-dst-1", "peer-router-1"}},
"group-res": {
ID: "group-res", Name: "Resource Group",
Resources: []types.Resource{{ID: "resource-1"}},
},
}
policies := []*types.Policy{
{
ID: "policy-base", Name: "Base connectivity", Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-base", Name: "Allow src <-> dst", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Bidirectional: true,
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
}},
},
{
ID: "policy-resource", Name: "Network resource access", Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-resource", Name: "Source -> Resource", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Sources: []string{"group-src"},
DestinationResource: types.Resource{ID: "resource-1"},
}},
},
}
routes := map[route.ID]*route.Route{
"route-main": {
ID: "route-main", Network: netip.MustParsePrefix("192.168.10.0/24"),
Peer: peers["peer-dst-1"].Key, PeerID: "peer-dst-1",
Enabled: true, Metric: 100,
Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-dst"},
},
}
users := map[string]*types.User{
"user-1": {Id: "user-1", Role: types.UserRoleAdmin, IsServiceUser: false, AutoGroups: []string{"group-all"}},
"user-2": {Id: "user-2", Role: types.UserRoleUser, IsServiceUser: false, AutoGroups: []string{"group-all"}},
}
account := &types.Account{
Id: "account-components-test", Peers: peers, Groups: groups, Policies: policies, Routes: routes,
Users: users,
Network: &types.Network{
Identifier: "net-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
},
DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{"group-dst"}},
NameServerGroups: map[string]*nbdns.NameServerGroup{
"ns-main": {
ID: "ns-main", Name: "Main NS", Enabled: true, Groups: []string{"group-src"},
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}},
},
},
PostureChecks: []*posture.Checks{},
NetworkResources: []*resourceTypes.NetworkResource{
{
ID: "resource-1", NetworkID: "net-1", AccountID: "account-components-test", Enabled: true,
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.0.1/32"), Address: "10.200.0.1/32",
},
},
Networks: []*networkTypes.Network{
{ID: "net-1", Name: "Resource Net", AccountID: "account-components-test"},
},
NetworkRouters: []*routerTypes.NetworkRouter{
{ID: "router-1", NetworkID: "net-1", Peer: "peer-router-1", Enabled: true, AccountID: "account-components-test"},
},
Settings: &types.Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: 24 * time.Hour},
}
for _, p := range account.Policies {
p.AccountID = account.Id
}
for _, r := range account.Routes {
r.AccountID = account.Id
}
return account
}

View File

@@ -1,967 +0,0 @@
package types_test
import (
"context"
"encoding/json"
"fmt"
"net"
"net/netip"
"os"
"path/filepath"
"slices"
"sort"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
const (
numPeers = 100
devGroupID = "group-dev"
opsGroupID = "group-ops"
allGroupID = "group-all"
sshUsersGroupID = "group-ssh-users"
routeID = route.ID("route-main")
routeHA1ID = route.ID("route-ha-1")
routeHA2ID = route.ID("route-ha-2")
policyIDDevOps = "policy-dev-ops"
policyIDAll = "policy-all"
policyIDPosture = "policy-posture"
policyIDDrop = "policy-drop"
policyIDSSH = "policy-ssh"
postureCheckID = "posture-check-ver"
networkResourceID = "res-database"
networkID = "net-database"
networkRouterID = "router-database"
nameserverGroupID = "ns-group-main"
testingPeerID = "peer-60" // A peer from the "dev" group, should receive the most detailed map.
expiredPeerID = "peer-98" // This peer will be online but with an expired session.
offlinePeerID = "peer-99" // This peer will be completely offline.
routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network.
testAccountID = "account-golden-test"
userAdminID = "user-admin"
userDevID = "user-dev"
userOpsID = "user-ops"
)
func TestGetPeerNetworkMap_Golden(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new.json")
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps from legacy and new builder do not match")
}
}
func BenchmarkGetPeerNetworkMap(b *testing.B) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
var peerIDs []string
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
validatedPeersMap[peerID] = struct{}{}
peerIDs = append(peerIDs, peerID)
}
b.ResetTimer()
b.Run("old builder", func(b *testing.B) {
for range b.N {
for _, peerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
b.ResetTimer()
b.Run("new builder", func(b *testing.B) {
for range b.N {
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
for _, peerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
}
}
})
}
func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newPeerID := "peer-new-101"
newPeerIP := net.IP{100, 64, 1, 1}
newPeer := &nbpeer.Peer{
ID: newPeerID,
IP: newPeerIP,
Key: fmt.Sprintf("key-%s", newPeerID),
DNSLabel: "peernew101",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
}
account.Peers[newPeerID] = newPeer
if devGroup, exists := account.Groups[devGroupID]; exists {
devGroup.Peers = append(devGroup.Peers, newPeerID)
}
if allGroup, exists := account.Groups[allGroupID]; exists {
allGroup.Peers = append(allGroup.Peers, newPeerID)
}
validatedPeersMap[newPeerID] = struct{}{}
if account.Network != nil {
account.Network.Serial++
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
err = builder.OnPeerAddedIncremental(account, newPeerID)
require.NoError(t, err, "error adding peer to cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json")
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new peer from legacy and new builder do not match")
}
}
func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
var peerIDs []string
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
validatedPeersMap[peerID] = struct{}{}
peerIDs = append(peerIDs, peerID)
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newPeerID := "peer-new-101"
newPeer := &nbpeer.Peer{
ID: newPeerID,
IP: net.IP{100, 64, 1, 1},
Key: fmt.Sprintf("key-%s", newPeerID),
DNSLabel: "peernew101",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
}
account.Peers[newPeerID] = newPeer
account.Groups[devGroupID].Peers = append(account.Groups[devGroupID].Peers, newPeerID)
account.Groups[allGroupID].Peers = append(account.Groups[allGroupID].Peers, newPeerID)
validatedPeersMap[newPeerID] = struct{}{}
b.ResetTimer()
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
b.ResetTimer()
b.Run("new builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = builder.OnPeerAddedIncremental(account, newPeerID)
for _, testingPeerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
}
}
})
}
func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newRouterID := "peer-new-router-102"
newRouterIP := net.IP{100, 64, 1, 2}
newRouter := &nbpeer.Peer{
ID: newRouterID,
IP: newRouterIP,
Key: fmt.Sprintf("key-%s", newRouterID),
DNSLabel: "newrouter102",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
}
account.Peers[newRouterID] = newRouter
if opsGroup, exists := account.Groups[opsGroupID]; exists {
opsGroup.Peers = append(opsGroup.Peers, newRouterID)
}
if allGroup, exists := account.Groups[allGroupID]; exists {
allGroup.Peers = append(allGroup.Peers, newRouterID)
}
newRoute := &route.Route{
ID: route.ID("route-new-router"),
Network: netip.MustParsePrefix("172.16.0.0/24"),
Peer: newRouter.Key,
PeerID: newRouterID,
Description: "Route from new router",
Enabled: true,
PeerGroups: []string{opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{devGroupID},
AccountID: account.Id,
}
account.Routes[newRoute.ID] = newRoute
validatedPeersMap[newRouterID] = struct{}{}
if account.Network != nil {
account.Network.Serial++
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
err = builder.OnPeerAddedIncremental(account, newRouterID)
require.NoError(t, err, "error adding router to cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new router from legacy and new builder do not match")
}
}
func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
var peerIDs []string
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
validatedPeersMap[peerID] = struct{}{}
peerIDs = append(peerIDs, peerID)
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newRouterID := "peer-new-router-102"
newRouterIP := net.IP{100, 64, 1, 2}
newRouter := &nbpeer.Peer{
ID: newRouterID,
IP: newRouterIP,
Key: fmt.Sprintf("key-%s", newRouterID),
DNSLabel: "newrouter102",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
}
account.Peers[newRouterID] = newRouter
if opsGroup, exists := account.Groups[opsGroupID]; exists {
opsGroup.Peers = append(opsGroup.Peers, newRouterID)
}
if allGroup, exists := account.Groups[allGroupID]; exists {
allGroup.Peers = append(allGroup.Peers, newRouterID)
}
newRoute := &route.Route{
ID: route.ID("route-new-router"),
Network: netip.MustParsePrefix("172.16.0.0/24"),
Peer: newRouter.Key,
PeerID: newRouterID,
Description: "Route from new router",
Enabled: true,
PeerGroups: []string{opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{devGroupID},
AccountID: account.Id,
}
account.Routes[newRoute.ID] = newRoute
validatedPeersMap[newRouterID] = struct{}{}
b.ResetTimer()
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
b.ResetTimer()
b.Run("new builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = builder.OnPeerAddedIncremental(account, newRouterID)
for _, testingPeerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
}
}
})
}
func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
deletedPeerID := "peer-25"
delete(account.Peers, deletedPeerID)
if devGroup, exists := account.Groups[devGroupID]; exists {
devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool {
return id == deletedPeerID
})
}
if allGroup, exists := account.Groups[allGroupID]; exists {
allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool {
return id == deletedPeerID
})
}
delete(validatedPeersMap, deletedPeerID)
if account.Network != nil {
account.Network.Serial++
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
err = builder.OnPeerDeleted(account, deletedPeerID)
require.NoError(t, err, "error deleting peer from cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json")
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted peer from legacy and new builder do not match")
}
}
func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
deletedRouterID := "peer-75"
var affectedRoute *route.Route
for _, r := range account.Routes {
if r.PeerID == deletedRouterID {
affectedRoute = r
break
}
}
require.NotNil(t, affectedRoute, "Router peer should have a route")
for _, group := range account.Groups {
group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool {
return id == deletedRouterID
})
}
for routeID, r := range account.Routes {
if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID {
delete(account.Routes, routeID)
}
}
delete(account.Peers, deletedRouterID)
delete(validatedPeersMap, deletedRouterID)
if account.Network != nil {
account.Network.Serial++
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
err = builder.OnPeerDeleted(account, deletedRouterID)
require.NoError(t, err, "error deleting routing peer from cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json")
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted router from legacy and new builder do not match")
}
}
func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
var peerIDs []string
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
validatedPeersMap[peerID] = struct{}{}
peerIDs = append(peerIDs, peerID)
}
deletedPeerID := "peer-25"
delete(account.Peers, deletedPeerID)
account.Groups[devGroupID].Peers = slices.DeleteFunc(account.Groups[devGroupID].Peers, func(id string) bool {
return id == deletedPeerID
})
account.Groups[allGroupID].Peers = slices.DeleteFunc(account.Groups[allGroupID].Peers, func(id string) bool {
return id == deletedPeerID
})
delete(validatedPeersMap, deletedPeerID)
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
b.ResetTimer()
b.Run("old builder after delete", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
b.ResetTimer()
b.Run("new builder after delete", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = builder.OnPeerDeleted(account, deletedPeerID)
for _, testingPeerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
}
}
})
}
func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) {
for _, peer := range networkMap.Peers {
if peer.Status != nil {
peer.Status.LastSeen = time.Time{}
}
peer.LastLogin = &time.Time{}
}
for _, peer := range networkMap.OfflinePeers {
if peer.Status != nil {
peer.Status.LastSeen = time.Time{}
}
peer.LastLogin = &time.Time{}
}
sort.Slice(networkMap.Peers, func(i, j int) bool { return networkMap.Peers[i].ID < networkMap.Peers[j].ID })
sort.Slice(networkMap.OfflinePeers, func(i, j int) bool { return networkMap.OfflinePeers[i].ID < networkMap.OfflinePeers[j].ID })
sort.Slice(networkMap.Routes, func(i, j int) bool { return networkMap.Routes[i].ID < networkMap.Routes[j].ID })
sort.Slice(networkMap.FirewallRules, func(i, j int) bool {
r1, r2 := networkMap.FirewallRules[i], networkMap.FirewallRules[j]
if r1.PeerIP != r2.PeerIP {
return r1.PeerIP < r2.PeerIP
}
if r1.Protocol != r2.Protocol {
return r1.Protocol < r2.Protocol
}
if r1.Direction != r2.Direction {
return r1.Direction < r2.Direction
}
if r1.Action != r2.Action {
return r1.Action < r2.Action
}
return r1.Port < r2.Port
})
sort.Slice(networkMap.RoutesFirewallRules, func(i, j int) bool {
r1, r2 := networkMap.RoutesFirewallRules[i], networkMap.RoutesFirewallRules[j]
if r1.RouteID != r2.RouteID {
return r1.RouteID < r2.RouteID
}
if r1.Action != r2.Action {
return r1.Action < r2.Action
}
if r1.Destination != r2.Destination {
return r1.Destination < r2.Destination
}
if len(r1.SourceRanges) > 0 && len(r2.SourceRanges) > 0 {
if r1.SourceRanges[0] != r2.SourceRanges[0] {
return r1.SourceRanges[0] < r2.SourceRanges[0]
}
}
return r1.Port < r2.Port
})
for _, ranges := range networkMap.RoutesFirewallRules {
sort.Slice(ranges.SourceRanges, func(i, j int) bool {
return ranges.SourceRanges[i] < ranges.SourceRanges[j]
})
}
}
type networkMapJSON struct {
Peers []*nbpeer.Peer `json:"Peers"`
Network *types.Network `json:"Network"`
Routes []*route.Route `json:"Routes"`
DNSConfig dns.Config `json:"DNSConfig"`
OfflinePeers []*nbpeer.Peer `json:"OfflinePeers"`
FirewallRules []*types.FirewallRule `json:"FirewallRules"`
RoutesFirewallRules []*types.RouteFirewallRule `json:"RoutesFirewallRules"`
ForwardingRules []*types.ForwardingRule `json:"ForwardingRules"`
AuthorizedUsers map[string][]string `json:"AuthorizedUsers,omitempty"`
EnableSSH bool `json:"EnableSSH"`
}
func toNetworkMapJSON(nm *types.NetworkMap) *networkMapJSON {
result := &networkMapJSON{
Peers: nm.Peers,
Network: nm.Network,
Routes: nm.Routes,
DNSConfig: nm.DNSConfig,
OfflinePeers: nm.OfflinePeers,
FirewallRules: nm.FirewallRules,
RoutesFirewallRules: nm.RoutesFirewallRules,
ForwardingRules: nm.ForwardingRules,
EnableSSH: nm.EnableSSH,
}
if len(nm.AuthorizedUsers) > 0 {
result.AuthorizedUsers = make(map[string][]string)
localUsers := make([]string, 0, len(nm.AuthorizedUsers))
for localUser := range nm.AuthorizedUsers {
localUsers = append(localUsers, localUser)
}
sort.Strings(localUsers)
for _, localUser := range localUsers {
userIDs := nm.AuthorizedUsers[localUser]
sortedUserIDs := make([]string, 0, len(userIDs))
for userID := range userIDs {
sortedUserIDs = append(sortedUserIDs, userID)
}
sort.Strings(sortedUserIDs)
result.AuthorizedUsers[localUser] = sortedUserIDs
}
}
return result
}
func createTestAccountWithEntities() *types.Account {
peers := make(map[string]*nbpeer.Peer)
devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{}
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
ip := net.IP{100, 64, 0, byte(i + 1)}
wtVersion := "0.25.0"
if i%2 == 0 {
wtVersion = "0.40.0"
}
p := &nbpeer.Peer{
ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1),
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"},
}
if peerID == expiredPeerID {
p.LoginExpirationEnabled = true
pastTimestamp := time.Now().Add(-2 * time.Hour)
p.LastLogin = &pastTimestamp
}
peers[peerID] = p
allGroupPeers = append(allGroupPeers, peerID)
if i < numPeers/2 {
devGroupPeers = append(devGroupPeers, peerID)
} else {
opsGroupPeers = append(opsGroupPeers, peerID)
}
}
groups := map[string]*types.Group{
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
sshUsersGroupID: {ID: sshUsersGroupID, Name: "SSH Users", Peers: []string{}},
}
policies := []*types.Policy{
{
ID: policyIDAll, Name: "Default-Allow", Enabled: true,
Rules: []*types.PolicyRule{{
ID: policyIDAll, Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolALL, Bidirectional: true,
Sources: []string{allGroupID}, Destinations: []string{allGroupID},
}},
},
{
ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true,
Rules: []*types.PolicyRule{{
ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolTCP, Bidirectional: false,
PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}},
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
}},
},
{
ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true,
Rules: []*types.PolicyRule{{
ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop,
Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true,
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
}},
},
{
ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true,
SourcePostureChecks: []string{postureCheckID},
Rules: []*types.PolicyRule{{
ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolALL, Bidirectional: true,
Sources: []string{opsGroupID}, DestinationResource: types.Resource{ID: networkResourceID},
}},
},
{
ID: policyIDSSH, Name: "SSH Access Policy", Enabled: true,
Rules: []*types.PolicyRule{{
ID: policyIDSSH, Name: "Allow SSH to Ops", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolNetbirdSSH, Bidirectional: false,
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
AuthorizedGroups: map[string][]string{sshUsersGroupID: {"root", "admin"}},
}},
},
}
routes := map[route.ID]*route.Route{
routeID: {
ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"),
Peer: peers["peer-75"].Key,
PeerID: "peer-75",
Description: "Route to internal resource", Enabled: true,
PeerGroups: []string{devGroupID, opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{devGroupID},
},
routeHA1ID: {
ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
Peer: peers["peer-80"].Key,
PeerID: "peer-80",
Description: "HA Route 1", Enabled: true, Metric: 1000,
PeerGroups: []string{allGroupID},
Groups: []string{allGroupID},
AccessControlGroups: []string{allGroupID},
},
routeHA2ID: {
ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
Peer: peers["peer-90"].Key,
PeerID: "peer-90",
Description: "HA Route 2", Enabled: true, Metric: 900,
PeerGroups: []string{devGroupID, opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{allGroupID},
},
}
users := map[string]*types.User{
userAdminID: {Id: userAdminID, Role: types.UserRoleAdmin, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{allGroupID}},
userDevID: {Id: userDevID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, devGroupID}},
userOpsID: {Id: userOpsID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, opsGroupID}},
}
account := &types.Account{
Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes,
Users: users,
Network: &types.Network{
Identifier: "net-golden-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
},
DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{opsGroupID}},
NameServerGroups: map[string]*dns.NameServerGroup{
nameserverGroupID: {
ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID},
NameServers: []dns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: dns.UDPNameServerType, Port: 53}},
},
},
PostureChecks: []*posture.Checks{
{ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
}},
},
NetworkResources: []*resourceTypes.NetworkResource{
{ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"},
},
Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}},
NetworkRouters: []*routerTypes.NetworkRouter{
{ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID},
},
Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
}
for _, p := range account.Policies {
p.AccountID = account.Id
}
for _, r := range account.Routes {
r.AccountID = account.Id
}
return account
}
func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newRouterID := "peer-new-router-102"
newRouterIP := net.IP{100, 64, 1, 2}
newRouter := &nbpeer.Peer{
ID: newRouterID,
IP: newRouterIP,
Key: fmt.Sprintf("key-%s", newRouterID),
DNSLabel: "newrouter102",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
}
account.Peers[newRouterID] = newRouter
if opsGroup, exists := account.Groups[opsGroupID]; exists {
opsGroup.Peers = append(opsGroup.Peers, newRouterID)
}
if allGroup, exists := account.Groups[allGroupID]; exists {
allGroup.Peers = append(allGroup.Peers, newRouterID)
}
newRoute := &route.Route{
ID: route.ID("route-new-router"),
Network: netip.MustParsePrefix("172.16.0.0/24"),
Peer: newRouter.Key,
PeerID: newRouterID,
Description: "Route from new router",
Enabled: true,
PeerGroups: []string{opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{devGroupID},
AccountID: account.Id,
}
account.Routes[newRoute.ID] = newRoute
validatedPeersMap[newRouterID] = struct{}{}
if account.Network != nil {
account.Network.Serial++
}
builder.EnqueuePeersForIncrementalAdd(account, newRouterID)
time.Sleep(100 * time.Millisecond)
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(networkMap)
jsonData, err := json.MarshalIndent(networkMap, "", " ")
require.NoError(t, err, "error marshaling network map to JSON")
goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")
t.Log("Update golden file with OnPeerAdded router...")
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(goldenFilePath, jsonData, 0644)
require.NoError(t, err)
expectedJSON, err := os.ReadFile(goldenFilePath)
require.NoError(t, err, "error reading golden file")
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file")
}

File diff suppressed because it is too large Load Diff

View File

@@ -3425,6 +3425,17 @@ components:
description: Display name for the admin user (defaults to email if not provided)
type: string
example: Admin User
create_pat:
description: If true and the server has setup-time PAT issuance enabled (NB_SETUP_PAT_ENABLED=true), create a Personal Access Token for the new owner user and return it in the response. Ignored when the server feature is disabled.
type: boolean
example: true
pat_expire_in:
description: Expiration of the Personal Access Token in days. Defaults to 1 day when omitted.
type: integer
minimum: 1
maximum: 365
default: 1
example: 30
required:
- email
- password
@@ -3441,6 +3452,10 @@ components:
description: Email address of the created user
type: string
example: admin@example.com
personal_access_token:
description: Plain text Personal Access Token created during setup. Present only when create_pat was requested and the NB_SETUP_PAT_ENABLED feature was enabled on the server.
type: string
example: nbp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
required:
- user_id
- email
@@ -4979,7 +4994,10 @@ paths:
/api/setup:
post:
summary: Setup Instance
description: Creates the initial admin user for the instance. This endpoint does not require authentication but only works when setup is required (no accounts exist and embedded IDP is enabled).
description: |
Creates the initial admin user for the instance. This endpoint does not require authentication but only works when setup is required (no accounts exist and embedded IDP is enabled).
When the management server is started with `NB_SETUP_PAT_ENABLED=true` and the request includes `create_pat: true`, the endpoint also provisions the NetBird account for the new owner user and returns the plain text Personal Access Token in `personal_access_token`. The optional `pat_expire_in` value defaults to 1 day when omitted. If any post-user step fails the Dex user is rolled back and setup remains retryable.
tags: [ Instance ]
security: [ ]
requestBody:

View File

@@ -4294,6 +4294,9 @@ type SetupKeyRequest struct {
// SetupRequest Request to set up the initial admin user
type SetupRequest struct {
// CreatePat If true and the server has setup-time PAT issuance enabled (NB_SETUP_PAT_ENABLED=true), create a Personal Access Token for the new owner user and return it in the response. Ignored when the server feature is disabled.
CreatePat *bool `json:"create_pat,omitempty"`
// Email Email address for the admin user
Email string `json:"email"`
@@ -4302,6 +4305,9 @@ type SetupRequest struct {
// Password Password for the admin user (minimum 8 characters)
Password string `json:"password"`
// PatExpireIn Expiration of the Personal Access Token in days. Defaults to 1 day when omitted.
PatExpireIn *int `json:"pat_expire_in,omitempty"`
}
// SetupResponse Response after successful instance setup
@@ -4309,6 +4315,9 @@ type SetupResponse struct {
// Email Email address of the created user
Email string `json:"email"`
// PersonalAccessToken Plain text Personal Access Token created during setup. Present only when create_pat was requested and the NB_SETUP_PAT_ENABLED feature was enabled on the server.
PersonalAccessToken *string `json:"personal_access_token,omitempty"`
// UserId The ID of the created user
UserId string `json:"user_id"`
}