Compare commits

...

17 Commits

Author SHA1 Message Date
crn4
e6a663ba20 merge main 2026-04-27 18:03:11 +02:00
crn4
7c4011d8e2 remove withContext from all sql calls 2026-04-27 17:28:56 +02:00
Vlad
154b81645a [management] removed legacy network map code (#5565) 2026-04-27 16:02:54 +02:00
Maycon Santos
34167c8a16 [misc] Update release pipeline version (#5995) 2026-04-27 10:55:38 +02:00
Maycon Santos
d6f08e4840 [misc] Update sign pipeline version (#5981) 2026-04-24 13:13:27 +02:00
Zoltan Papp
f732b01a05 [management] unify peer-update test timeout via constant (#5952)
peerShouldReceiveUpdate waited 500ms for the expected update message,
and every outer wrapper across the management/server test suite paired
it with a 1s goroutine-drain timeout. Both were too tight for slower
CI runners (MySQL, FreeBSD, loaded sqlite), producing intermittent
"Timed out waiting for update message" failures in tests like
TestDNSAccountPeersUpdate, TestPeerAccountPeersUpdate, and
TestNameServerAccountPeersUpdate.

Introduce peerUpdateTimeout (5s) next to the helper and use it both in
the helper and in every outer wrapper so the two timeouts stay in sync.
Only runs down on failure; passing tests return as soon as the channel
delivers, so there is no slowdown on green runs.
2026-04-23 21:19:21 +02:00
crn4
8fe2b5ec1e removed condition on 1 yop per account 2026-04-21 15:12:52 +02:00
crn4
e62521132c Merge remote-tracking branch 'origin/main' into feat/byod-proxy
# Conflicts:
#	management/internals/modules/reverseproxy/domain/manager/manager.go
#	management/internals/modules/reverseproxy/proxy/manager.go
#	management/internals/modules/reverseproxy/proxy/manager/manager.go
#	management/internals/modules/reverseproxy/proxy/manager_mock.go
#	management/internals/shared/grpc/proxy.go
#	management/server/store/sql_store.go
#	proxy/management_integration_test.go
2026-04-13 17:02:24 +03:00
crn4
de3cb06067 added proxy id to cluster api response 2026-03-31 00:28:19 +02:00
crn4
4fdc39c8f8 review comments 2026-03-24 15:37:31 +01:00
crn4
94149a9441 linter 2026-03-24 14:58:03 +01:00
crn4
38fd73fad6 merge main 2026-03-24 14:50:03 +01:00
crn4
9dd76b5a07 merge main 2026-03-24 14:20:03 +01:00
crn4
0b5380a7dc review comments 2026-03-24 13:32:38 +01:00
crn4
177171e437 change api 2026-03-19 21:49:04 +01:00
crn4
da57b0f276 rename byod to byop 2026-03-19 16:11:57 +01:00
crn4
26ba03f08e [proxy] feature: bring your own proxy 2026-03-19 01:02:46 +01:00
53 changed files with 3171 additions and 5554 deletions

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.1.2" SIGN_PIPE_VER: "v0.1.4"
GORELEASER_VER: "v2.14.3" GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"

View File

@@ -7,7 +7,6 @@ import (
"os" "os"
"slices" "slices"
"strconv" "strconv"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -16,11 +15,9 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
@@ -58,13 +55,6 @@ type Controller struct {
proxyController port_forwarding.Controller proxyController port_forwarding.Controller
integratedPeerValidator integrated_validator.IntegratedValidator integratedPeerValidator integrated_validator.IntegratedValidator
holder *types.Holder
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
compactedNetworkMap bool
} }
type bufferUpdate struct { type bufferUpdate struct {
@@ -81,29 +71,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
log.Fatal(fmt.Errorf("error creating metrics: %w", err)) log.Fatal(fmt.Errorf("error creating metrics: %w", err))
} }
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
newNetworkMapBuilder = false
}
compactedNetworkMap := true
compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted)
parsedCompactedNmap, err := strconv.ParseBool(compactedEnv)
if err != nil && len(compactedEnv) > 0 {
log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err)
}
if err == nil && !parsedCompactedNmap {
log.WithContext(ctx).Info("disabling compacted mode")
compactedNetworkMap = false
}
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
expIDs[id] = struct{}{}
}
return &Controller{ return &Controller{
repo: newRepository(store), repo: newRepository(store),
metrics: nMetrics, metrics: nMetrics,
@@ -117,12 +84,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
proxyController: proxyController, proxyController: proxyController,
EphemeralPeersManager: ephemeralPeersManager, EphemeralPeersManager: ephemeralPeersManager,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
compactedNetworkMap: compactedNetworkMap,
} }
} }
@@ -153,17 +114,9 @@ func (c *Controller) CountStreams() int {
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error { func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
var ( account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
account *types.Account if err != nil {
err error return fmt.Errorf("failed to get account: %v", err)
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(ctx, accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
} }
globalStart := time.Now() globalStart := time.Now()
@@ -197,10 +150,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
routers := account.GetResourceRoutersMap() routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers() groupIDToUserIDs := account.GetActiveGroupUsers()
if c.experimentalNetworkMap(accountID) {
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
}
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
@@ -243,16 +192,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPostureChecksDuration(time.Since(start)) c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now() start = time.Now()
var remotePeerNetworkMap *types.NetworkMap remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
switch {
case c.experimentalNetworkMap(accountID):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -318,10 +258,6 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
// UpdatePeers updates all peers that belong to an account. // UpdatePeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers. // Should be called when changes have to be synced to peers.
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error { func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return fmt.Errorf("recalculate network map cache: %v", err)
}
return c.sendUpdateAccountPeers(ctx, accountID) return c.sendUpdateAccountPeers(ctx, accountID)
} }
@@ -371,16 +307,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return err return err
} }
var remotePeerNetworkMap *types.NetworkMap remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
switch {
case c.experimentalNetworkMap(accountId):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok { if ok {
@@ -451,17 +378,9 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return peer, emptyMap, nil, 0, nil return peer, emptyMap, nil, 0, nil
} }
var ( account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
account *types.Account if err != nil {
err error return nil, nil, nil, 0, err
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(ctx, accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
} }
account.InjectProxyPolicies(ctx) account.InjectProxyPolicies(ctx)
@@ -493,20 +412,10 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return nil, nil, nil, 0, err return nil, nil, nil, 0, err
} }
var networkMap *types.NetworkMap resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
if c.experimentalNetworkMap(accountID) { groupIDToUserIDs := account.GetActiveGroupUsers()
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok { if ok {
@@ -518,108 +427,6 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return peer, networkMap, postureChecks, dnsFwdPort, nil return peer, networkMap, postureChecks, dnsFwdPort, nil
} }
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
c.enrichAccountFromHolder(account)
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
}
func (c *Controller) getPeerNetworkMapExp(
ctx context.Context,
accountId string,
peerId string,
validatedPeers map[string]struct{},
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
account := c.getAccountFromHolderOrInit(ctx, accountId)
if account == nil {
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
return &types.NetworkMap{
Network: &types.Network{},
}
}
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
}
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
c.enrichAccountFromHolder(account)
account.OnPeersAddedUpdNetworkMapCache(peerIds...)
}
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
c.enrichAccountFromHolder(account)
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
}
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
account := c.getAccountFromHolder(accountId)
if account == nil {
return
}
account.UpdatePeerInNetworkMapCache(peer)
}
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
account.RecalculateNetworkMapCache(validatedPeers)
c.updateAccountInHolder(account)
}
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
if c.experimentalNetworkMap(accountId) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
return err
}
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
return err
}
c.recalculateNetworkMapCache(account, validatedPeers)
}
return nil
}
func (c *Controller) experimentalNetworkMap(accountId string) bool {
_, ok := c.expNewNetworkMapAIDs[accountId]
return c.expNewNetworkMap || ok
}
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
a := c.holder.GetAccount(account.Id)
if a == nil {
c.holder.AddAccount(account)
return
}
account.NetworkMapCache = a.NetworkMapCache
if account.NetworkMapCache == nil {
return
}
c.holder.AddAccount(account)
}
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
return c.holder.GetAccount(accountID)
}
func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account {
a := c.holder.GetAccount(accountID)
if a != nil {
return a
}
account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure)
if err != nil {
return nil
}
return account
}
func (c *Controller) updateAccountInHolder(account *types.Account) {
c.holder.AddAccount(account)
}
// GetDNSDomain returns the configured dnsDomain // GetDNSDomain returns the configured dnsDomain
func (c *Controller) GetDNSDomain(settings *types.Settings) string { func (c *Controller) GetDNSDomain(settings *types.Settings) string {
if settings == nil { if settings == nil {
@@ -756,16 +563,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
} }
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error { func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs) err := c.bufferSendUpdateAccountPeers(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get peers by ids: %w", err)
}
for _, peer := range peers {
c.UpdatePeerInNetworkMapCache(accountID, peer)
}
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err) log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
} }
@@ -775,14 +573,6 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs) log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs)
c.onPeersAddedUpdNetworkMapCache(account, peerIDs...)
}
return c.bufferSendUpdateAccountPeers(ctx, accountID) return c.bufferSendUpdateAccountPeers(ctx, accountID)
} }
@@ -817,19 +607,6 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
MessageType: network_map.MessageTypeNetworkMap, MessageType: network_map.MessageTypeNetworkMap,
}) })
c.peersUpdateManager.CloseChannel(ctx, peerID) c.peersUpdateManager.CloseChannel(ctx, peerID)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
continue
}
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
continue
}
}
} }
return c.bufferSendUpdateAccountPeers(ctx, accountID) return c.bufferSendUpdateAccountPeers(ctx, accountID)
@@ -872,21 +649,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
return nil, err return nil, err
} }
var networkMap *types.NetworkMap account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
if c.experimentalNetworkMap(peer.AccountID) { routers := account.GetResourceRoutersMap()
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil) groupIDToUserIDs := account.GetActiveGroupUsers()
} else { networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok { if ok {

View File

@@ -12,9 +12,6 @@ import (
) )
const ( const (
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
DnsForwarderPort = nbdns.ForwarderServerPort DnsForwarderPort = nbdns.ForwarderServerPort
OldForwarderPort = nbdns.ForwarderClientPort OldForwarderPort = nbdns.ForwarderClientPort
DnsForwarderPortMinVersion = "v0.59.0" DnsForwarderPortMinVersion = "v0.59.0"

View File

@@ -31,6 +31,7 @@ type store interface {
type proxyManager interface { type proxyManager interface {
GetActiveClusterAddresses(ctx context.Context) ([]string, error) GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
@@ -71,8 +72,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
var ret []*domain.Domain var ret []*domain.Domain
// Add connected proxy clusters as free domains. // Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io"). // For BYOP accounts, only their own cluster is returned; otherwise shared clusters.
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err) log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err return nil, err
@@ -126,8 +127,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
// Verify the target cluster is in the available clusters // Verify the target cluster is in the available clusters for this account
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err) return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
} }
@@ -273,7 +274,7 @@ func (m Manager) GetClusterDomains() []string {
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain. // For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster. // For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) { func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err) return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
} }
@@ -298,6 +299,17 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain) return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
} }
func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) {
byopAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get BYOP cluster addresses: %w", err)
}
if len(byopAddresses) > 0 {
return byopAddresses, nil
}
return m.proxyManager.GetActiveClusterAddresses(ctx)
}
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) { func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
bestCluster := "" bestCluster := ""
bestLen := -1 bestLen := -1

View File

@@ -0,0 +1,110 @@
package manager
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockProxyManager struct {
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
}
func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
if m.getActiveClusterAddressesFunc != nil {
return m.getActiveClusterAddressesFunc(ctx)
}
return nil, nil
}
func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
if m.getActiveClusterAddressesForAccountFunc != nil {
return m.getActiveClusterAddressesForAccountFunc(ctx, accountID)
}
return nil, nil
}
func (m *mockProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil
}
func (m *mockProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool {
return nil
}
func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
return nil
}
func TestGetClusterAllowList_BYOPProxy(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
assert.Equal(t, "acc-123", accID)
return []string{"byop.example.com"}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist")
return nil, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"byop.example.com"}, result)
}
func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return nil, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return nil, errors.New("db error")
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails")
return nil, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "BYOP cluster addresses")
}
func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
}

View File

@@ -11,15 +11,20 @@ import (
// Manager defines the interface for proxy operations // Manager defines the interface for proxy operations
type Manager interface { type Manager interface {
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) error
Disconnect(ctx context.Context, proxyID string) error Disconnect(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error) GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
GetActiveClusters(ctx context.Context) ([]Cluster, error) GetActiveClusters(ctx context.Context) ([]Cluster, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error)
DeleteProxy(ctx context.Context, proxyID string) error
} }
// OIDCValidationConfig contains the OIDC configuration needed for token validation. // OIDCValidationConfig contains the OIDC configuration needed for token validation.

View File

@@ -13,13 +13,19 @@ import (
// store defines the interface for proxy persistence operations // store defines the interface for proxy persistence operations
type store interface { type store interface {
SaveProxy(ctx context.Context, p *proxy.Proxy) error SaveProxy(ctx context.Context, p *proxy.Proxy) error
DisconnectProxy(ctx context.Context, proxyID string) error
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
DeleteProxy(ctx context.Context, proxyID string) error
} }
// Manager handles all proxy operations // Manager handles all proxy operations
@@ -43,7 +49,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
// Connect registers a new proxy connection in the database. // Connect registers a new proxy connection in the database.
// capabilities may be nil for old proxies that do not report them. // capabilities may be nil for old proxies that do not report them.
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error { func (m *Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string, capabilities *proxy.Capabilities) error {
now := time.Now() now := time.Now()
var caps proxy.Capabilities var caps proxy.Capabilities
if capabilities != nil { if capabilities != nil {
@@ -53,9 +59,10 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
ID: proxyID, ID: proxyID,
ClusterAddress: clusterAddress, ClusterAddress: clusterAddress,
IPAddress: ipAddress, IPAddress: ipAddress,
AccountID: accountID,
LastSeen: now, LastSeen: now,
ConnectedAt: &now, ConnectedAt: &now,
Status: "connected", Status: proxy.StatusConnected,
Capabilities: caps, Capabilities: caps,
} }
@@ -74,16 +81,8 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
} }
// Disconnect marks a proxy as disconnected in the database // Disconnect marks a proxy as disconnected in the database
func (m Manager) Disconnect(ctx context.Context, proxyID string) error { func (m *Manager) Disconnect(ctx context.Context, proxyID string) error {
now := time.Now() if err := m.store.DisconnectProxy(ctx, proxyID); err != nil {
p := &proxy.Proxy{
ID: proxyID,
Status: "disconnected",
DisconnectedAt: &now,
LastSeen: now,
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err) log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
return err return err
} }
@@ -96,7 +95,7 @@ func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
} }
// Heartbeat updates the proxy's last seen timestamp // Heartbeat updates the proxy's last seen timestamp
func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { func (m *Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil { if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err) log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
return err return err
@@ -108,7 +107,7 @@ func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddre
} }
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies // GetActiveClusterAddresses returns all unique cluster addresses for active proxies
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx) addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err) log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
@@ -146,10 +145,44 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
} }
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration // CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil { if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err) log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
return err return err
} }
return nil return nil
} }
func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err)
return nil, err
}
return addresses, nil
}
func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) {
return m.store.GetProxyByAccountID(ctx, accountID)
}
func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
return m.store.CountProxiesByAccountID(ctx, accountID)
}
func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID)
if err != nil {
return false, err
}
return !conflicting, nil
}
func (m *Manager) DeleteProxy(ctx context.Context, proxyID string) error {
if err := m.store.DeleteProxy(ctx, proxyID); err != nil {
log.WithContext(ctx).Errorf("failed to delete proxy %s: %v", proxyID, err)
return err
}
return nil
}

View File

@@ -0,0 +1,334 @@
package manager
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
type mockStore struct {
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
disconnectProxyFunc func(ctx context.Context, proxyID string) error
updateProxyHeartbeatFunc func(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
deleteProxyFunc func(ctx context.Context, proxyID string) error
}
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
if m.saveProxyFunc != nil {
return m.saveProxyFunc(ctx, p)
}
return nil
}
func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID string) error {
if m.disconnectProxyFunc != nil {
return m.disconnectProxyFunc(ctx, proxyID)
}
return nil
}
func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
if m.updateProxyHeartbeatFunc != nil {
return m.updateProxyHeartbeatFunc(ctx, proxyID, clusterAddress, ipAddress)
}
return nil
}
func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
if m.getActiveProxyClusterAddressesFunc != nil {
return m.getActiveProxyClusterAddressesFunc(ctx)
}
return nil, nil
}
func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
if m.getActiveProxyClusterAddressesForAccFunc != nil {
return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID)
}
return nil, nil
}
func (m *mockStore) GetActiveProxyClusters(_ context.Context) ([]proxy.Cluster, error) {
return nil, nil
}
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
if m.cleanupStaleProxiesFunc != nil {
return m.cleanupStaleProxiesFunc(ctx, d)
}
return nil
}
func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
if m.getProxyByAccountIDFunc != nil {
return m.getProxyByAccountIDFunc(ctx, accountID)
}
return nil, fmt.Errorf("proxy not found for account %s", accountID)
}
func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
if m.countProxiesByAccountIDFunc != nil {
return m.countProxiesByAccountIDFunc(ctx, accountID)
}
return 0, nil
}
func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
if m.isClusterAddressConflictingFunc != nil {
return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID)
}
return false, nil
}
func (m *mockStore) DeleteProxy(ctx context.Context, proxyID string) error {
if m.deleteProxyFunc != nil {
return m.deleteProxyFunc(ctx, proxyID)
}
return nil
}
func (m *mockStore) GetClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil
}
func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *bool {
return nil
}
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
return nil
}
func newTestManager(s store) *Manager {
meter := noop.NewMeterProvider().Meter("test")
m, err := NewManager(s, meter)
if err != nil {
panic(err)
}
return m
}
func TestConnect_WithAccountID(t *testing.T) {
accountID := "acc-123"
var savedProxy *proxy.Proxy
s := &mockStore{
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
savedProxy = p
return nil
},
}
mgr := newTestManager(s)
err := mgr.Connect(context.Background(), "proxy-1", "cluster.example.com", "10.0.0.1", &accountID, nil)
require.NoError(t, err)
require.NotNil(t, savedProxy)
assert.Equal(t, "proxy-1", savedProxy.ID)
assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress)
assert.Equal(t, "10.0.0.1", savedProxy.IPAddress)
assert.Equal(t, &accountID, savedProxy.AccountID)
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
assert.NotNil(t, savedProxy.ConnectedAt)
}
func TestConnect_WithoutAccountID(t *testing.T) {
var savedProxy *proxy.Proxy
s := &mockStore{
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
savedProxy = p
return nil
},
}
mgr := newTestManager(s)
err := mgr.Connect(context.Background(), "proxy-1", "eu.proxy.netbird.io", "10.0.0.1", nil, nil)
require.NoError(t, err)
require.NotNil(t, savedProxy)
assert.Nil(t, savedProxy.AccountID)
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
}
func TestConnect_StoreError(t *testing.T) {
s := &mockStore{
saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error {
return errors.New("db error")
},
}
mgr := newTestManager(s)
err := mgr.Connect(context.Background(), "proxy-1", "cluster.example.com", "10.0.0.1", nil, nil)
assert.Error(t, err)
}
func TestIsClusterAddressAvailable(t *testing.T) {
tests := []struct {
name string
conflicting bool
storeErr error
wantResult bool
wantErr bool
}{
{
name: "available - no conflict",
conflicting: false,
wantResult: true,
},
{
name: "not available - conflict exists",
conflicting: true,
wantResult: false,
},
{
name: "store error",
storeErr: errors.New("db error"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &mockStore{
isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) {
return tt.conflicting, tt.storeErr
},
}
mgr := newTestManager(s)
result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123")
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantResult, result)
})
}
}
func TestCountAccountProxies(t *testing.T) {
tests := []struct {
name string
count int64
storeErr error
wantCount int64
wantErr bool
}{
{
name: "no proxies",
count: 0,
wantCount: 0,
},
{
name: "one proxy",
count: 1,
wantCount: 1,
},
{
name: "store error",
storeErr: errors.New("db error"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &mockStore{
countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) {
return tt.count, tt.storeErr
},
}
mgr := newTestManager(s)
count, err := mgr.CountAccountProxies(context.Background(), "acc-123")
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantCount, count)
})
}
}
func TestGetAccountProxy(t *testing.T) {
accountID := "acc-123"
t.Run("found", func(t *testing.T) {
expected := &proxy.Proxy{
ID: "proxy-1",
ClusterAddress: "byop.example.com",
AccountID: &accountID,
Status: proxy.StatusConnected,
}
s := &mockStore{
getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) {
assert.Equal(t, accountID, accID)
return expected, nil
},
}
mgr := newTestManager(s)
p, err := mgr.GetAccountProxy(context.Background(), accountID)
require.NoError(t, err)
assert.Equal(t, expected, p)
})
t.Run("not found", func(t *testing.T) {
s := &mockStore{
getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) {
return nil, errors.New("not found")
},
}
mgr := newTestManager(s)
_, err := mgr.GetAccountProxy(context.Background(), accountID)
assert.Error(t, err)
})
}
func TestDeleteProxy(t *testing.T) {
t.Run("success", func(t *testing.T) {
var deletedID string
s := &mockStore{
deleteProxyFunc: func(_ context.Context, proxyID string) error {
deletedID = proxyID
return nil
},
}
mgr := newTestManager(s)
err := mgr.DeleteProxy(context.Background(), "proxy-1")
require.NoError(t, err)
assert.Equal(t, "proxy-1", deletedID)
})
t.Run("store error", func(t *testing.T) {
s := &mockStore{
deleteProxyFunc: func(_ context.Context, _ string) error {
return errors.New("db error")
},
}
mgr := newTestManager(s)
err := mgr.DeleteProxy(context.Background(), "proxy-1")
assert.Error(t, err)
})
}
func TestGetActiveClusterAddressesForAccount(t *testing.T) {
expected := []string{"byop.example.com"}
s := &mockStore{
getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) {
assert.Equal(t, "acc-123", accID)
return expected, nil
},
}
mgr := newTestManager(s)
result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, expected, result)
}

View File

@@ -93,17 +93,17 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
} }
// Connect mocks base method. // Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error { func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities) ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, accountID, capabilities)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// Connect indicates an expected call of Connect. // Connect indicates an expected call of Connect.
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call { func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, accountID, capabilities interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, accountID, capabilities)
} }
// Disconnect mocks base method. // Disconnect mocks base method.
@@ -135,7 +135,19 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
} }
// GetActiveClusters mocks base method. func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID)
}
func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) { func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusters", ctx) ret := m.ctrl.Call(m, "GetActiveClusters", ctx)
@@ -164,6 +176,65 @@ func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAdd
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress)
} }
// GetAccountProxy mocks base method.
func (m *MockManager) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAccountProxy", ctx, accountID)
ret0, _ := ret[0].(*Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAccountProxy indicates an expected call of GetAccountProxy.
func (mr *MockManagerMockRecorder) GetAccountProxy(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountProxy", reflect.TypeOf((*MockManager)(nil).GetAccountProxy), ctx, accountID)
}
// CountAccountProxies mocks base method.
func (m *MockManager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountAccountProxies", ctx, accountID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountAccountProxies indicates an expected call of CountAccountProxies.
func (mr *MockManagerMockRecorder) CountAccountProxies(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountProxies", reflect.TypeOf((*MockManager)(nil).CountAccountProxies), ctx, accountID)
}
// IsClusterAddressAvailable mocks base method.
func (m *MockManager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsClusterAddressAvailable", ctx, clusterAddress, accountID)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsClusterAddressAvailable indicates an expected call of IsClusterAddressAvailable.
func (mr *MockManagerMockRecorder) IsClusterAddressAvailable(ctx, clusterAddress, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressAvailable", reflect.TypeOf((*MockManager)(nil).IsClusterAddressAvailable), ctx, clusterAddress, accountID)
}
// DeleteProxy mocks base method.
func (m *MockManager) DeleteProxy(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteProxy", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteProxy indicates an expected call of DeleteProxy.
func (mr *MockManagerMockRecorder) DeleteProxy(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteProxy", reflect.TypeOf((*MockManager)(nil).DeleteProxy), ctx, proxyID)
}
// MockController is a mock of Controller interface. // MockController is a mock of Controller interface.
type MockController struct { type MockController struct {
ctrl *gomock.Controller ctrl *gomock.Controller

View File

@@ -1,6 +1,13 @@
package proxy package proxy
import "time" import (
"time"
)
const (
StatusConnected = "connected"
StatusDisconnected = "disconnected"
)
// Capabilities describes what a proxy can handle, as reported via gRPC. // Capabilities describes what a proxy can handle, as reported via gRPC.
// Nil fields mean the proxy never reported this capability. // Nil fields mean the proxy never reported this capability.
@@ -20,6 +27,7 @@ type Proxy struct {
ID string `gorm:"primaryKey;type:varchar(255)"` ID string `gorm:"primaryKey;type:varchar(255)"`
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"` ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
IPAddress string `gorm:"type:varchar(45)"` IPAddress string `gorm:"type:varchar(45)"`
AccountID *string `gorm:"type:varchar(255);index:idx_proxy_account_id"`
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"` LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time ConnectedAt *time.Time
DisconnectedAt *time.Time DisconnectedAt *time.Time
@@ -35,6 +43,7 @@ func (Proxy) TableName() string {
// Cluster represents a group of proxy nodes serving the same address. // Cluster represents a group of proxy nodes serving the same address.
type Cluster struct { type Cluster struct {
ID string
Address string Address string
ConnectedProxies int ConnectedProxies int
} }

View File

@@ -0,0 +1,195 @@
package proxytoken
import (
"encoding/json"
"net/http"
"time"
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
store store.Store
permissionsManager permissions.Manager
}
func RegisterEndpoints(s store.Store, permissionsManager permissions.Manager, router *mux.Router) {
h := &handler{store: s, permissionsManager: permissionsManager}
router.HandleFunc("/reverse-proxies/proxy-tokens", h.listTokens).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/proxy-tokens", h.createToken).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/proxy-tokens/{tokenId}", h.revokeToken).Methods("DELETE", "OPTIONS")
}
func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
var req api.ProxyTokenRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if req.Name == "" || len(req.Name) > 255 {
util.WriteErrorResponse("name is required and must be at most 255 characters", http.StatusBadRequest, w)
return
}
var expiresIn time.Duration
if req.ExpiresIn != nil {
if *req.ExpiresIn < 0 {
util.WriteErrorResponse("expires_in must be non-negative", http.StatusBadRequest, w)
return
}
if *req.ExpiresIn > 0 {
expiresIn = time.Duration(*req.ExpiresIn) * time.Second
}
}
accountID := userAuth.AccountId
generated, err := types.CreateNewProxyAccessToken(req.Name, expiresIn, &accountID, userAuth.UserId)
if err != nil {
util.WriteErrorResponse("failed to generate token", http.StatusInternalServerError, w)
return
}
if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil {
util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w)
return
}
resp := toProxyTokenCreatedResponse(generated)
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId)
if err != nil {
util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w)
return
}
resp := make([]api.ProxyToken, 0, len(tokens))
for _, token := range tokens {
resp = append(resp, toProxyTokenResponse(token))
}
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
tokenID := mux.Vars(r)["tokenId"]
if tokenID == "" {
util.WriteErrorResponse("token ID is required", http.StatusBadRequest, w)
return
}
token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID)
if err != nil {
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
} else {
util.WriteErrorResponse("failed to retrieve token", http.StatusInternalServerError, w)
}
return
}
if token.AccountID == nil || *token.AccountID != userAuth.AccountId {
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
return
}
if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil {
util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken {
resp := api.ProxyToken{
Id: token.ID,
Name: token.Name,
Revoked: token.Revoked,
}
if !token.CreatedAt.IsZero() {
resp.CreatedAt = token.CreatedAt
}
if token.ExpiresAt != nil {
resp.ExpiresAt = token.ExpiresAt
}
if token.LastUsed != nil {
resp.LastUsed = token.LastUsed
}
return resp
}
func toProxyTokenCreatedResponse(generated *types.ProxyAccessTokenGenerated) api.ProxyTokenCreated {
base := toProxyTokenResponse(&generated.ProxyAccessToken)
plainToken := string(generated.PlainToken)
return api.ProxyTokenCreated{
Id: base.Id,
Name: base.Name,
CreatedAt: base.CreatedAt,
ExpiresAt: base.ExpiresAt,
LastUsed: base.LastUsed,
Revoked: base.Revoked,
PlainToken: plainToken,
}
}

View File

@@ -0,0 +1,275 @@
package proxytoken
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func authContext(accountID, userID string) context.Context {
return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{
AccountId: accountID,
UserId: userID,
})
}
func TestCreateToken_AccountScoped(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
var savedToken *types.ProxyAccessToken
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, token *types.ProxyAccessToken) error {
savedToken = token
return nil
},
)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
body := `{"name": "my-token"}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext(accountID, "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var resp api.ProxyTokenCreated
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
assert.NotEmpty(t, resp.PlainToken)
assert.Equal(t, "my-token", resp.Name)
assert.False(t, resp.Revoked)
require.NotNil(t, savedToken)
require.NotNil(t, savedToken.AccountID)
assert.Equal(t, accountID, *savedToken.AccountID)
assert.Equal(t, "user-1", savedToken.CreatedBy)
}
func TestCreateToken_WithExpiration(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
var savedToken *types.ProxyAccessToken
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, token *types.ProxyAccessToken) error {
savedToken = token
return nil
},
)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
body := `{"name": "expiring-token", "expires_in": 3600}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
require.NotNil(t, savedToken)
require.NotNil(t, savedToken.ExpiresAt)
assert.True(t, savedToken.ExpiresAt.After(time.Now()))
}
func TestCreateToken_EmptyName(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
permissionsManager: permsMgr,
}
body := `{"name": ""}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
}
func TestCreateToken_PermissionDenied(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil)
h := &handler{
permissionsManager: permsMgr,
}
body := `{"name": "test"}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusForbidden, w.Code)
}
func TestListTokens(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
now := time.Now()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokensByAccountID(gomock.Any(), store.LockingStrengthNone, accountID).Return([]*types.ProxyAccessToken{
{ID: "tok-1", Name: "token-1", AccountID: &accountID, CreatedAt: now, Revoked: false},
{ID: "tok-2", Name: "token-2", AccountID: &accountID, CreatedAt: now, Revoked: true},
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("GET", "/reverse-proxies/proxy-tokens", nil)
req = req.WithContext(authContext(accountID, "user-1"))
w := httptest.NewRecorder()
h.listTokens(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var resp []api.ProxyToken
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
require.Len(t, resp, 2)
assert.Equal(t, "tok-1", resp[0].Id)
assert.False(t, resp[0].Revoked)
assert.Equal(t, "tok-2", resp[1].Id)
assert.True(t, resp[1].Revoked)
}
func TestRevokeToken_Success(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
Name: "test-token",
AccountID: &accountID,
}, nil)
mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext(accountID, "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestRevokeToken_WrongAccount(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
otherAccount := "acc-other"
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
AccountID: &otherAccount,
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext("acc-123", "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
func TestRevokeToken_ManagementWideToken(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
AccountID: nil,
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext("acc-123", "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}

View File

@@ -28,4 +28,5 @@ type Manager interface {
RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StartExposeReaper(ctx context.Context) StartExposeReaper(ctx context.Context)
GetServiceByDomain(ctx context.Context, domain string) (*Service, error)
} }

View File

@@ -138,6 +138,21 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
} }
// GetServiceByDomain mocks base method.
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
}
// GetGlobalServices mocks base method. // GetGlobalServices mocks base method.
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) { func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -195,6 +195,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
apiClusters := make([]api.ProxyCluster, 0, len(clusters)) apiClusters := make([]api.ProxyCluster, 0, len(clusters))
for _, c := range clusters { for _, c := range clusters {
apiClusters = append(apiClusters, api.ProxyCluster{ apiClusters = append(apiClusters, api.ProxyCluster{
Id: c.ID,
Address: c.Address, Address: c.Address,
ConnectedProxies: c.ConnectedProxies, ConnectedProxies: c.ConnectedProxies,
}) })

View File

@@ -984,6 +984,10 @@ func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*
return services, nil return services, nil
} }
func (m *Manager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) { func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID) target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil { if err != nil {

View File

@@ -433,7 +433,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
t.Helper() t.Helper()
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t)) tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t)) pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
return srv return srv
} }
@@ -712,7 +712,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t)) pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err) require.NoError(t, err)
@@ -1135,7 +1135,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t)) pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err) require.NoError(t, err)

View File

@@ -193,7 +193,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager()) proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
s.AfterInit(func(s *BaseServer) { s.AfterInit(func(s *BaseServer) {
proxyService.SetServiceManager(s.ServiceManager()) proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController()) proxyService.SetProxyController(s.ServiceProxyController())

View File

@@ -9,6 +9,7 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@@ -47,6 +48,11 @@ type ProxyOIDCConfig struct {
KeysLocation string KeysLocation string
} }
// ProxyTokenChecker checks whether a proxy access token is still valid.
type ProxyTokenChecker interface {
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
}
// ProxyServiceServer implements the ProxyService gRPC server // ProxyServiceServer implements the ProxyService gRPC server
type ProxyServiceServer struct { type ProxyServiceServer struct {
proto.UnimplementedProxyServiceServer proto.UnimplementedProxyServiceServer
@@ -75,6 +81,9 @@ type ProxyServiceServer struct {
// Store for one-time authentication tokens // Store for one-time authentication tokens
tokenStore *OneTimeTokenStore tokenStore *OneTimeTokenStore
// Checker for proxy access token validity
tokenChecker ProxyTokenChecker
// OIDC configuration for proxy authentication // OIDC configuration for proxy authentication
oidcConfig ProxyOIDCConfig oidcConfig ProxyOIDCConfig
@@ -90,6 +99,8 @@ const pkceVerifierTTL = 10 * time.Minute
type proxyConnection struct { type proxyConnection struct {
proxyID string proxyID string
address string address string
accountID *string
tokenID string
capabilities *proto.ProxyCapabilities capabilities *proto.ProxyCapabilities
stream proto.ProxyService_GetMappingUpdateServer stream proto.ProxyService_GetMappingUpdateServer
sendChan chan *proto.GetMappingUpdateResponse sendChan chan *proto.GetMappingUpdateResponse
@@ -97,8 +108,19 @@ type proxyConnection struct {
cancel context.CancelFunc cancel context.CancelFunc
} }
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
token := GetProxyTokenFromContext(ctx)
if token == nil || token.AccountID == nil {
return nil
}
if requestAccountID == "" || *token.AccountID != requestAccountID {
return status.Errorf(codes.PermissionDenied, "account-scoped token cannot access account %s", requestAccountID)
}
return nil
}
// NewProxyServiceServer creates a new proxy service server. // NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer { func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{ s := &ProxyServiceServer{
accessLogManager: accessLogMgr, accessLogManager: accessLogMgr,
@@ -108,6 +130,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
peersManager: peersManager, peersManager: peersManager,
usersManager: usersManager, usersManager: usersManager,
proxyManager: proxyMgr, proxyManager: proxyMgr,
tokenChecker: tokenChecker,
cancel: cancel, cancel: cancel,
} }
go s.cleanupStaleProxies(ctx) go s.cleanupStaleProxies(ctx)
@@ -166,10 +189,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
return status.Errorf(codes.InvalidArgument, "proxy address is invalid") return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
} }
var accountID *string
token := GetProxyTokenFromContext(ctx)
if token != nil && token.AccountID != nil {
accountID = token.AccountID
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
if err != nil {
return status.Errorf(codes.Internal, "check cluster address: %v", err)
}
if !available {
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
}
}
var tokenID string
if token != nil {
tokenID = token.ID
}
connCtx, cancel := context.WithCancel(ctx) connCtx, cancel := context.WithCancel(ctx)
conn := &proxyConnection{ conn := &proxyConnection{
proxyID: proxyID, proxyID: proxyID,
address: proxyAddress, address: proxyAddress,
accountID: accountID,
tokenID: tokenID,
capabilities: req.GetCapabilities(), capabilities: req.GetCapabilities(),
stream: stream, stream: stream,
sendChan: make(chan *proto.GetMappingUpdateResponse, 100), sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
@@ -177,12 +221,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
cancel: cancel, cancel: cancel,
} }
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
// Register proxy in database with capabilities
var caps *proxy.Capabilities var caps *proxy.Capabilities
if c := req.GetCapabilities(); c != nil { if c := req.GetCapabilities(); c != nil {
caps = &proxy.Capabilities{ caps = &proxy.Capabilities{
@@ -191,19 +229,25 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
SupportsCrowdsec: c.SupportsCrowdsec, SupportsCrowdsec: c.SupportsCrowdsec,
} }
} }
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil { if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, accountID, caps); err != nil {
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err) if accountID != nil {
s.connectedProxies.Delete(proxyID) cancel()
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil { return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
} }
return status.Errorf(codes.Internal, "register proxy in database: %v", err) log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
} }
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"proxy_id": proxyID, "proxy_id": proxyID,
"address": proxyAddress, "address": proxyAddress,
"cluster_addr": proxyAddress, "cluster_addr": proxyAddress,
"account_id": accountID,
"total_proxies": len(s.GetConnectedProxies()), "total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster") }).Info("Proxy registered in cluster")
defer func() { defer func() {
@@ -228,7 +272,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
go s.sender(conn, errChan) go s.sender(conn, errChan)
// Start heartbeat goroutine // Start heartbeat goroutine
go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo) go s.heartbeat(connCtx, conn, peerInfo)
select { select {
case err := <-errChan: case err := <-errChan:
@@ -238,16 +282,28 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
} }
} }
// heartbeat updates the proxy's last_seen timestamp every minute func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, ipAddress string) {
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) {
ticker := time.NewTicker(1 * time.Minute) ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil { if err := s.proxyManager.Heartbeat(ctx, conn.proxyID, conn.address, ipAddress); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err) log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", conn.proxyID, err)
}
if conn.tokenID != "" && s.tokenChecker != nil {
valid, err := s.tokenChecker.IsProxyAccessTokenValid(ctx, conn.tokenID)
if err != nil {
log.WithContext(ctx).Warnf("failed to check token validity for proxy %s: %v", conn.proxyID, err)
continue
}
if !valid {
log.WithContext(ctx).Warnf("proxy %s token revoked or expired, disconnecting", conn.proxyID)
conn.cancel()
return
}
} }
case <-ctx.Done(): case <-ctx.Done():
return return
@@ -255,8 +311,6 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddr
} }
} }
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
// Only entries matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
if !isProxyAddressValid(conn.address) { if !isProxyAddressValid(conn.address) {
return fmt.Errorf("proxy address is invalid") return fmt.Errorf("proxy address is invalid")
@@ -289,7 +343,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
} }
func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) { func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) {
services, err := s.serviceManager.GetGlobalServices(ctx) var services []*rpservice.Service
var err error
if conn.accountID != nil {
services, err = s.serviceManager.GetAccountServices(ctx, *conn.accountID)
} else {
services, err = s.serviceManager.GetGlobalServices(ctx)
}
if err != nil { if err != nil {
return nil, fmt.Errorf("get services from store: %w", err) return nil, fmt.Errorf("get services from store: %w", err)
} }
@@ -318,8 +378,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
return mappings, nil return mappings, nil
} }
// isProxyAddressValid validates a proxy address // isProxyAddressValid validates a proxy address (domain name or IP address)
func isProxyAddressValid(addr string) bool { func isProxyAddressValid(addr string) bool {
if addr == "" {
return false
}
if net.ParseIP(addr) != nil {
return true
}
_, err := domain.ValidateDomains([]string{addr}) _, err := domain.ValidateDomains([]string{addr})
return err == nil return err == nil
} }
@@ -343,6 +409,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) { func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
accessLog := req.GetLog() accessLog := req.GetLog()
if err := enforceAccountScope(ctx, accessLog.GetAccountId()); err != nil {
return nil, err
}
fields := log.Fields{ fields := log.Fields{
"service_id": accessLog.GetServiceId(), "service_id": accessLog.GetServiceId(),
"account_id": accessLog.GetAccountId(), "account_id": accessLog.GetAccountId(),
@@ -380,11 +450,32 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
// Management should call this when services are created/updated/removed. // Management should call this when services are created/updated/removed.
// For create/update operations a unique one-time auth token is generated per // For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management. // proxy so that every replica can independently authenticate with management.
// BYOP proxies only receive updates for their own account's services.
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) { func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
log.Debugf("Broadcasting service update to all connected proxy servers") log.Debugf("Broadcasting service update to all connected proxy servers")
updateAccountIDs := make(map[string]struct{})
for _, m := range update.Mapping {
if m.AccountId != "" {
updateAccountIDs[m.AccountId] = struct{}{}
}
}
s.connectedProxies.Range(func(key, value interface{}) bool { s.connectedProxies.Range(func(key, value interface{}) bool {
conn := value.(*proxyConnection) conn := value.(*proxyConnection)
resp := s.perProxyMessage(update, conn.proxyID) connUpdate := update
if conn.accountID != nil && len(updateAccountIDs) > 0 {
if _, ok := updateAccountIDs[*conn.accountID]; !ok {
return true
}
filtered := filterMappingsForAccount(update.Mapping, *conn.accountID)
if len(filtered) == 0 {
return true
}
connUpdate = &proto.GetMappingUpdateResponse{
Mapping: filtered,
InitialSyncComplete: update.InitialSyncComplete,
}
}
resp := s.perProxyMessage(connUpdate, conn.proxyID)
if resp == nil { if resp == nil {
return true return true
} }
@@ -398,6 +489,26 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
}) })
} }
// ForceDisconnect cancels the gRPC stream for a connected proxy, causing it to disconnect.
func (s *ProxyServiceServer) ForceDisconnect(proxyID string) {
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection)
conn.cancel()
s.connectedProxies.Delete(proxyID)
log.WithFields(log.Fields{"proxyID": proxyID}).Info("force disconnected proxy")
}
}
func filterMappingsForAccount(mappings []*proto.ProxyMapping, accountID string) []*proto.ProxyMapping {
var filtered []*proto.ProxyMapping
for _, m := range mappings {
if m.AccountId == accountID {
filtered = append(filtered, m)
}
}
return filtered
}
// GetConnectedProxies returns a list of connected proxy IDs // GetConnectedProxies returns a list of connected proxy IDs
func (s *ProxyServiceServer) GetConnectedProxies() []string { func (s *ProxyServiceServer) GetConnectedProxies() []string {
var proxies []string var proxies []string
@@ -466,6 +577,9 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
continue continue
} }
conn := connVal.(*proxyConnection) conn := connVal.(*proxyConnection)
if conn.accountID != nil && update.AccountId != "" && *conn.accountID != update.AccountId {
continue
}
if !proxyAcceptsMapping(conn, update) { if !proxyAcceptsMapping(conn, update) {
log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id) log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id)
continue continue
@@ -549,6 +663,10 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
} }
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get service from store: %v", err) log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
@@ -668,6 +786,10 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
// SendStatusUpdate handles status updates from proxy clients. // SendStatusUpdate handles status updates from proxy clients.
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) { func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
accountID := req.GetAccountId() accountID := req.GetAccountId()
serviceID := req.GetServiceId() serviceID := req.GetServiceId()
protoStatus := req.GetStatus() protoStatus := req.GetStatus()
@@ -738,6 +860,10 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
// CreateProxyPeer handles proxy peer creation with one-time token authentication // CreateProxyPeer handles proxy peer creation with one-time token authentication
func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) { func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
serviceID := req.GetServiceId() serviceID := req.GetServiceId()
accountID := req.GetAccountId() accountID := req.GetAccountId()
token := req.GetToken() token := req.GetToken()
@@ -792,6 +918,10 @@ func strPtr(s string) *string {
} }
func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) { func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
redirectURL, err := url.Parse(req.GetRedirectUrl()) redirectURL, err := url.Parse(req.GetRedirectUrl())
if err != nil { if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err) return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
@@ -920,21 +1050,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
// GenerateSessionToken creates a signed session JWT for the given domain and user. // GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) { func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
// Find the service by domain to get its signing key service, err := s.getServiceByDomain(ctx, domain)
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil { if err != nil {
return "", fmt.Errorf("get services: %w", err) return "", fmt.Errorf("service not found for domain %s: %w", domain, err)
}
var service *rpservice.Service
for _, svc := range services {
if svc.Domain == domain {
service = svc
break
}
}
if service == nil {
return "", fmt.Errorf("service not found for domain: %s", domain)
} }
if service.SessionPrivateKey == "" { if service.SessionPrivateKey == "" {
@@ -1032,6 +1150,10 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil }, nil
} }
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
return nil, err
}
pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey) pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey)
if err != nil { if err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@@ -1115,18 +1237,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
} }
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) { func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
services, err := s.serviceManager.GetGlobalServices(ctx) return s.serviceManager.GetServiceByDomain(ctx, domain)
if err != nil {
return nil, fmt.Errorf("get services: %w", err)
}
for _, service := range services {
if service.Domain == domain {
return service, nil
}
}
return nil, fmt.Errorf("service not found for domain: %s", domain)
} }
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error { func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {

View File

@@ -0,0 +1,29 @@
package grpc
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsProxyAddressValid(t *testing.T) {
tests := []struct {
name string
addr string
valid bool
}{
{name: "valid domain", addr: "eu.proxy.netbird.io", valid: true},
{name: "valid subdomain", addr: "byop.proxy.example.com", valid: true},
{name: "valid IPv4", addr: "10.0.0.1", valid: true},
{name: "valid IPv4 public", addr: "203.0.113.10", valid: true},
{name: "valid IPv6", addr: "::1", valid: true},
{name: "valid IPv6 full", addr: "2001:db8::1", valid: true},
{name: "empty string", addr: "", valid: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.valid, isProxyAddressValid(tt.addr))
})
}
}

View File

@@ -153,9 +153,6 @@ func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types
return nil, status.Errorf(codes.Unauthenticated, "invalid token") return nil, status.Errorf(codes.Unauthenticated, "invalid token")
} }
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
// Currently tokens are management-wide; AccountID field is reserved for future use.
if !token.IsValid() { if !token.IsValid() {
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked") return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
} }

View File

@@ -91,6 +91,20 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {} func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain string) (*service.Service, error) {
if m.err != nil {
return nil, m.err
}
for _, services := range m.proxiesByAccount {
for _, svc := range services {
if svc.Domain == domain {
return svc, nil
}
}
}
return nil, errors.New("service not found for domain: " + domain)
}
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
return nil, nil return nil, nil
} }

View File

@@ -12,9 +12,12 @@ import (
cachestore "github.com/eko/gocache/lib/v4/store" cachestore "github.com/eko/gocache/lib/v4/store"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
) )
@@ -313,6 +316,58 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
assert.Contains(t, err.Error(), "invalid state format") assert.Contains(t, err.Error(), "invalid state format")
} }
func scopedCtx(accountID string) context.Context {
token := &types.ProxyAccessToken{
ID: "token-1",
AccountID: &accountID,
}
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
}
func globalCtx() context.Context {
token := &types.ProxyAccessToken{
ID: "token-global",
}
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
}
func TestEnforceAccountScope_AllowsMatchingAccount(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "acc-1")
assert.NoError(t, err)
}
func TestEnforceAccountScope_BlocksMismatchedAccount(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "acc-2")
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.PermissionDenied, st.Code())
}
func TestEnforceAccountScope_BlocksEmptyRequestAccountID(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "")
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.PermissionDenied, st.Code())
}
func TestEnforceAccountScope_AllowsGlobalToken(t *testing.T) {
err := enforceAccountScope(globalCtx(), "acc-1")
assert.NoError(t, err)
err = enforceAccountScope(globalCtx(), "acc-2")
assert.NoError(t, err)
err = enforceAccountScope(globalCtx(), "")
assert.NoError(t, err)
}
func TestEnforceAccountScope_AllowsNoTokenInContext(t *testing.T) {
err := enforceAccountScope(context.Background(), "acc-1")
assert.NoError(t, err)
}
func TestValidateState_RejectsInvalidHMAC(t *testing.T) { func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
ctx := context.Background() ctx := context.Background()
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))

View File

@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager) proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
proxyService.SetServiceManager(serviceManager) proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore) createTestProxies(t, ctx, testStore)
@@ -318,13 +318,17 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {} func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
return nil, nil return nil, nil
} }
type testValidateSessionProxyManager struct{} type testValidateSessionProxyManager struct{}
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *proxy.Capabilities) error { func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error {
return nil return nil
} }
@@ -340,6 +344,10 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co
return nil, nil return nil, nil
} }
func (m *testValidateSessionProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) { func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) {
return nil, nil return nil, nil
} }
@@ -348,6 +356,22 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time
return nil return nil
} }
func (m *testValidateSessionProxyManager) GetAccountProxy(_ context.Context, _ string) (*proxy.Proxy, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
return 0, nil
}
func (m *testValidateSessionProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
return true, nil
}
func (m *testValidateSessionProxyManager) DeleteProxy(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool { func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil return nil
} }

View File

@@ -408,7 +408,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
} }
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) networkMap := account.GetPeerNetworkMapFromComponents(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
} }
@@ -1171,11 +1171,6 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
} }
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SaveGroup(t)
}
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
testAccountManager_NetworkUpdates_SaveGroup(t) testAccountManager_NetworkUpdates_SaveGroup(t)
} }
@@ -1231,11 +1226,6 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePolicy(t)
}
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
testAccountManager_NetworkUpdates_DeletePolicy(t) testAccountManager_NetworkUpdates_DeletePolicy(t)
} }
@@ -1274,11 +1264,6 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SavePolicy(t)
}
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
testAccountManager_NetworkUpdates_SavePolicy(t) testAccountManager_NetworkUpdates_SavePolicy(t)
} }
@@ -1332,11 +1317,6 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePeer(t)
}
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
testAccountManager_NetworkUpdates_DeletePeer(t) testAccountManager_NetworkUpdates_DeletePeer(t)
} }
@@ -1397,11 +1377,6 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeleteGroup(t)
}
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
testAccountManager_NetworkUpdates_DeleteGroup(t) testAccountManager_NetworkUpdates_DeleteGroup(t)
} }
@@ -1633,75 +1608,6 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
assert.Contains(t, routeIDs, route.ID("route-2")) assert.Contains(t, routeIDs, route.ID("route-2"))
} }
func TestAccount_GetRoutesToSync(t *testing.T) {
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
if err != nil {
t.Fatal(err)
}
_, prefix2, err := route.ParseNetwork("192.168.0.0/24")
if err != nil {
t.Fatal(err)
}
account := &types.Account{
Peers: map[string]*nbpeer.Peer{
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
},
Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
Routes: map[route.ID]*route.Route{
"route-1": {
ID: "route-1",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-1",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
Groups: []string{"group1"},
},
"route-2": {
ID: "route-2",
Network: prefix2,
NetID: "network-2",
Description: "network-2",
Peer: "peer-2",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
Groups: []string{"group1"},
},
"route-3": {
ID: "route-3",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-2",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
Groups: []string{"group1"},
},
},
}
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
assert.Len(t, routes, 2)
routeIDs := make(map[route.ID]struct{}, 2)
for _, r := range routes {
routeIDs[r.ID] = struct{}{}
}
assert.Contains(t, routeIDs, route.ID("route-2"))
assert.Contains(t, routeIDs, route.ID("route-3"))
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
assert.Len(t, emptyRoutes, 0)
}
func TestAccount_Copy(t *testing.T) { func TestAccount_Copy(t *testing.T) {
account := &types.Account{ account := &types.Account{
Id: "account1", Id: "account1",
@@ -1824,9 +1730,7 @@ func TestAccount_Copy(t *testing.T) {
AccountID: "account1", AccountID: "account1",
}, },
}, },
NetworkMapCache: &types.NetworkMapBuilder{},
} }
account.InitOnce()
err := hasNilField(account) err := hasNilField(account)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -3170,7 +3074,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err return nil, nil, err
} }
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager) proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{}) proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@@ -3253,6 +3157,13 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
return manager, updateManager, account, peer1, peer2, peer3 return manager, updateManager, account, peer1, peer2, peer3
} }
// peerUpdateTimeout bounds how long peerShouldReceiveUpdate and its outer
// wrappers wait for an expected update message. Sized for slow CI runners
// (MySQL, FreeBSD, loaded sqlite) where the channel publish can take
// seconds. Only runs down on failure; passing tests return immediately
// when the channel delivers.
const peerUpdateTimeout = 5 * time.Second
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) { func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper() t.Helper()
select { select {
@@ -3271,7 +3182,7 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.Upd
if msg == nil { if msg == nil {
t.Errorf("Received nil update message, expected valid message") t.Errorf("Received nil update message, expected valid message")
} }
case <-time.After(500 * time.Millisecond): case <-time.After(peerUpdateTimeout):
t.Error("Timed out waiting for update message") t.Error("Timed out waiting for update message")
} }
} }

View File

@@ -458,7 +458,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -478,7 +478,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -518,7 +518,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })

View File

@@ -620,7 +620,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -638,7 +638,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -656,7 +656,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -689,7 +689,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -730,7 +730,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -757,7 +757,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -804,7 +804,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
@@ -146,6 +147,9 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
if serviceManager != nil && reverseProxyDomainManager != nil { if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router) reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
} }
proxytoken.RegisterEndpoints(accountManager.GetStore(), permissionsManager, router)
// Register OAuth callback handler for proxy authentication // Register OAuth callback handler for proxy authentication
if proxyGRPCServer != nil { if proxyGRPCServer != nil {
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies) oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies)

View File

@@ -417,7 +417,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings) dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers()) netMap := account.GetPeerNetworkMapFromComponents(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
} }

View File

@@ -216,6 +216,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
nil, nil,
usersManager, usersManager,
nil, nil,
nil,
) )
proxyService.SetServiceManager(&testServiceManager{store: testStore}) proxyService.SetServiceManager(&testServiceManager{store: testStore})
@@ -435,6 +436,10 @@ func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ stri
func (m *testServiceManager) StartExposeReaper(_ context.Context) {} func (m *testServiceManager) StartExposeReaper(_ context.Context) {}
func (m *testServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) { func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
return nil, nil return nil, nil
} }

View File

@@ -109,7 +109,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
if err != nil { if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err) t.Fatalf("Failed to create proxy manager: %v", err)
} }
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil { if err != nil {
@@ -238,7 +238,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
if err != nil { if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err) t.Fatalf("Failed to create proxy manager: %v", err)
} }
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil { if err != nil {

View File

@@ -1087,7 +1087,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1105,7 +1105,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })

View File

@@ -179,11 +179,6 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
testGetNetworkMapGeneral(t) testGetNetworkMapGeneral(t)
} }
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testGetNetworkMapGeneral(t)
}
func testGetNetworkMapGeneral(t *testing.T) { func testGetNetworkMapGeneral(t *testing.T) {
manager, _, err := createManager(t) manager, _, err := createManager(t)
if err != nil { if err != nil {
@@ -1016,11 +1011,6 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
} }
} }
func TestUpdateAccountPeers_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testUpdateAccountPeers(t)
}
func TestUpdateAccountPeers(t *testing.T) { func TestUpdateAccountPeers(t *testing.T) {
testUpdateAccountPeers(t) testUpdateAccountPeers(t)
} }
@@ -1600,7 +1590,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
} }
func Test_LoginPeer(t *testing.T) { func Test_LoginPeer(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet") t.Skip("The SQLite store is not properly supported by Windows yet")
} }
@@ -1907,7 +1896,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1929,7 +1918,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1994,7 +1983,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2012,7 +2001,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2058,7 +2047,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2076,7 +2065,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2113,7 +2102,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2131,7 +2120,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })

View File

@@ -1231,7 +1231,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1263,7 +1263,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1294,7 +1294,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1314,7 +1314,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1355,7 +1355,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1373,7 +1373,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
@@ -1393,7 +1393,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })

View File

@@ -244,7 +244,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -273,7 +273,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -292,7 +292,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -395,7 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -438,7 +438,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })

View File

@@ -2,10 +2,8 @@ package server
import ( import (
"context" "context"
"fmt"
"net" "net"
"net/netip" "net/netip"
"sort"
"testing" "testing"
"time" "time"
@@ -1840,11 +1838,6 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
}, },
} }
validatedPeers := make(map[string]struct{})
for p := range account.Peers {
validatedPeers[p] = struct{}{}
}
t.Run("check applied policies for the route", func(t *testing.T) { t.Run("check applied policies for the route", func(t *testing.T) {
route1 := account.Routes["route1"] route1 := account.Routes["route1"]
policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups)
@@ -1858,116 +1851,6 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups)
assert.Len(t, policies, 0) assert.Len(t, policies, 0)
}) })
t.Run("check peer routes firewall rules", func(t *testing.T) {
routesFirewallRules := account.GetPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
assert.Len(t, routesFirewallRules, 4)
expectedRoutesFirewallRules := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
Protocol: "all",
Port: 80,
RouteID: "route1:peerA",
},
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
Protocol: "all",
Port: 320,
RouteID: "route1:peerA",
},
}
additionalFirewallRule := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerJIp),
},
Action: "accept",
Destination: "192.168.10.0/16",
Protocol: "tcp",
Port: 80,
RouteID: "route4:peerA",
},
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerKIp),
},
Action: "accept",
Destination: "192.168.10.0/16",
Protocol: "all",
RouteID: "route4:peerA",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...)))
// peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
assert.Len(t, routesFirewallRules, 2)
for _, rule := range expectedRoutesFirewallRules {
rule.RouteID = "route1:peerD"
}
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
// peerE is a single routing peer for route 2 and route 3
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
assert.Len(t, routesFirewallRules, 3)
expectedRoutesFirewallRules = []*types.RouteFirewallRule{
{
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
Action: "accept",
Destination: existingNetwork.String(),
Protocol: "tcp",
PortRange: types.RulePortRange{Start: 80, End: 350},
RouteID: "route2",
},
{
SourceRanges: []string{"0.0.0.0/0"},
Action: "accept",
Destination: "192.0.2.0/32",
Protocol: "all",
Domains: domain.List{"example.com"},
IsDynamic: true,
RouteID: "route3",
},
{
SourceRanges: []string{"::/0"},
Action: "accept",
Destination: "192.0.2.0/32",
Protocol: "all",
Domains: domain.List{"example.com"},
IsDynamic: true,
RouteID: "route3",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
// peerC is part of route1 distribution groups but should not receive the routes firewall rules
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
assert.Len(t, routesFirewallRules, 0)
})
}
// orderList is a helper function to sort a list of strings
func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFirewallRule {
for _, rule := range ruleList {
sort.Strings(rule.SourceRanges)
}
return ruleList
} }
func TestRouteAccountPeersUpdate(t *testing.T) { func TestRouteAccountPeersUpdate(t *testing.T) {
@@ -2070,7 +1953,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
@@ -2107,7 +1990,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2127,7 +2010,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2145,7 +2028,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2185,7 +2068,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2225,7 +2108,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -2665,11 +2548,6 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) {
}, },
} }
validatedPeers := make(map[string]struct{})
for p := range account.Peers {
validatedPeers[p] = struct{}{}
}
t.Run("validate applied policies for different network resources", func(t *testing.T) { t.Run("validate applied policies for different network resources", func(t *testing.T) {
// Test case: Resource1 is directly applied to the policy (policyResource1) // Test case: Resource1 is directly applied to the policy (policyResource1)
policies := account.GetPoliciesForNetworkResource("resource1") policies := account.GetPoliciesForNetworkResource("resource1")
@@ -2693,127 +2571,4 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) {
policies = account.GetPoliciesForNetworkResource("resource6") policies = account.GetPoliciesForNetworkResource("resource6")
assert.Len(t, policies, 1, "resource6 should have exactly 1 policy applied via access control groups") assert.Len(t, policies, 1, "resource6 should have exactly 1 policy applied via access control groups")
}) })
t.Run("validate routing peer firewall rules for network resources", func(t *testing.T) {
resourcePoliciesMap := account.GetResourcePoliciesMap()
resourceRoutersMap := account.GetResourceRoutersMap()
_, routes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), "peerA", resourcePoliciesMap, resourceRoutersMap)
firewallRules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerA"], validatedPeers, routes, resourcePoliciesMap)
assert.Len(t, firewallRules, 4)
assert.Len(t, sourcePeers, 5)
expectedFirewallRules := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
Protocol: "all",
Port: 80,
RouteID: "resource2:peerA",
},
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
Protocol: "all",
Port: 320,
RouteID: "resource2:peerA",
},
}
additionalFirewallRules := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerJIp),
},
Action: "accept",
Destination: "192.0.2.0/32",
Protocol: "tcp",
Port: 80,
Domains: domain.List{"example.com"},
IsDynamic: true,
RouteID: "resource4:peerA",
},
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerKIp),
},
Action: "accept",
Destination: "192.0.2.0/32",
Protocol: "all",
Domains: domain.List{"example.com"},
IsDynamic: true,
RouteID: "resource4:peerA",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(append(expectedFirewallRules, additionalFirewallRules...)))
// peerD is also the routing peer for resource2
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerD", resourcePoliciesMap, resourceRoutersMap)
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerD"], validatedPeers, routes, resourcePoliciesMap)
assert.Len(t, firewallRules, 2)
for _, rule := range expectedFirewallRules {
rule.RouteID = "resource2:peerD"
}
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
assert.Len(t, sourcePeers, 3)
// peerE is a single routing peer for resource1 and resource3
// PeerE should only receive rules for resource1 since resource3 has no applied policy
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerE", resourcePoliciesMap, resourceRoutersMap)
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerE"], validatedPeers, routes, resourcePoliciesMap)
assert.Len(t, firewallRules, 1)
assert.Len(t, sourcePeers, 2)
expectedFirewallRules = []*types.RouteFirewallRule{
{
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
Action: "accept",
Destination: "10.10.10.0/24",
Protocol: "tcp",
PortRange: types.RulePortRange{Start: 80, End: 350},
RouteID: "resource1:peerE",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
// peerC is part of distribution groups for resource2 but should not receive the firewall rules
firewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
assert.Len(t, firewallRules, 0)
// peerL is the single routing peer for resource5
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerL", resourcePoliciesMap, resourceRoutersMap)
assert.Len(t, routes, 1)
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerL"], validatedPeers, routes, resourcePoliciesMap)
assert.Len(t, firewallRules, 1)
assert.Len(t, sourcePeers, 1)
expectedFirewallRules = []*types.RouteFirewallRule{
{
SourceRanges: []string{"100.65.29.67/32"},
Action: "accept",
Destination: "10.12.12.1/32",
Protocol: "tcp",
Port: 8080,
RouteID: "resource5:peerL",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerM", resourcePoliciesMap, resourceRoutersMap)
assert.Len(t, routes, 1)
assert.Len(t, sourcePeers, 0)
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerN", resourcePoliciesMap, resourceRoutersMap)
assert.Len(t, routes, 1)
assert.Len(t, sourcePeers, 2)
})
} }

View File

@@ -1196,7 +1196,6 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
account.NameServerGroups[ns.ID] = &ns account.NameServerGroups[ns.ID] = &ns
} }
account.NameServerGroupsG = nil account.NameServerGroupsG = nil
account.InitOnce()
return &account, nil return &account, nil
} }
@@ -1635,7 +1634,6 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc
if sExtraIntegratedValidatorGroups.Valid { if sExtraIntegratedValidatorGroups.Valid {
_ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups) _ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups)
} }
account.InitOnce()
return &account, nil return &account, nil
} }
@@ -4497,6 +4495,47 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e
return nil return nil
} }
func (s *SqlStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var tokens []*types.ProxyAccessToken
result := tx.Where("account_id = ?", accountID).Find(&tokens)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "get proxy access tokens by account: %v", result.Error)
}
return tokens, nil
}
func (s *SqlStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) {
token, err := s.GetProxyAccessTokenByID(ctx, LockingStrengthNone, tokenID)
if err != nil {
return false, err
}
return token.IsValid(), nil
}
func (s *SqlStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var token types.ProxyAccessToken
result := tx.Take(&token, idQueryCondition, tokenID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "proxy access token not found")
}
return nil, status.Errorf(status.Internal, "get proxy access token by ID: %v", result.Error)
}
return &token, nil
}
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token. // MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error { func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
result := s.db.Model(&types.ProxyAccessToken{}). result := s.db.Model(&types.ProxyAccessToken{}).
@@ -5439,13 +5478,29 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
return nil return nil
} }
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID string) error {
now := time.Now()
result := s.db.
Model(&proxy.Proxy{}).
Where("id = ?", proxyID).
Updates(map[string]interface{}{
"status": proxy.StatusDisconnected,
"disconnected_at": now,
"last_seen": now,
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy: %v", result.Error)
return status.Errorf(status.Internal, "failed to disconnect proxy")
}
return nil
}
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
now := time.Now() now := time.Now()
result := s.db. result := s.db.
Model(&proxy.Proxy{}). Model(&proxy.Proxy{}).
Where("id = ? AND status = ?", proxyID, "connected"). Where("id = ? AND status = ?", proxyID, proxy.StatusConnected).
Update("last_seen", now) Update("last_seen", now)
if result.Error != nil { if result.Error != nil {
@@ -5477,7 +5532,7 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
result := s.db. result := s.db.
Model(&proxy.Proxy{}). Model(&proxy.Proxy{}).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). Where("status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)).
Distinct("cluster_address"). Distinct("cluster_address").
Pluck("cluster_address", &addresses) Pluck("cluster_address", &addresses)
@@ -5489,13 +5544,72 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
return addresses, nil return addresses, nil
} }
// GetActiveProxyClusters returns all active proxy clusters with their connected proxy count. func (s *SqlStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
var addresses []string
result := s.db.
Model(&proxy.Proxy{}).
Where("account_id = ? AND status = ? AND last_seen > ?", accountID, proxy.StatusConnected, time.Now().Add(-2*time.Minute)).
Distinct("cluster_address").
Pluck("cluster_address", &addresses)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses for account")
}
return addresses, nil
}
func (s *SqlStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
var p proxy.Proxy
result := s.db.Where("account_id = ?", accountID).Take(&p)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "proxy not found for account")
}
return nil, status.Errorf(status.Internal, "get proxy by account ID: %v", result.Error)
}
return &p, nil
}
func (s *SqlStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
var count int64
result := s.db.Model(&proxy.Proxy{}).Where("account_id = ?", accountID).Count(&count)
if result.Error != nil {
return 0, status.Errorf(status.Internal, "count proxies by account ID: %v", result.Error)
}
return count, nil
}
func (s *SqlStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
var count int64
result := s.db.
Model(&proxy.Proxy{}).
Where("cluster_address = ? AND (account_id IS NULL OR account_id != ?)", clusterAddress, accountID).
Count(&count)
if result.Error != nil {
return false, status.Errorf(status.Internal, "check cluster address conflict: %v", result.Error)
}
return count > 0, nil
}
func (s *SqlStore) DeleteProxy(ctx context.Context, proxyID string) error {
result := s.db.Where(idQueryCondition, proxyID).Delete(&proxy.Proxy{})
if result.Error != nil {
return status.Errorf(status.Internal, "delete proxy: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "proxy not found")
}
return nil
}
func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) { func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
var clusters []proxy.Cluster var clusters []proxy.Cluster
result := s.db.Model(&proxy.Proxy{}). result := s.db.Model(&proxy.Proxy{}).
Select("cluster_address as address, COUNT(*) as connected_proxies"). Select("MIN(id) as id, cluster_address as address, COUNT(*) as connected_proxies").
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)). Where("status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)).
Group("cluster_address"). Group("cluster_address").
Scan(&clusters) Scan(&clusters)

View File

@@ -114,6 +114,9 @@ type Store interface {
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error)
GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error)
GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error)
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error
RevokeProxyAccessToken(ctx context.Context, tokenID string) error RevokeProxyAccessToken(ctx context.Context, tokenID string) error
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
@@ -284,13 +287,19 @@ type Store interface {
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
DisconnectProxy(ctx context.Context, proxyID string) error
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
DeleteProxy(ctx context.Context, proxyID string) error
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
@@ -494,6 +503,9 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error { func(db *gorm.DB) error {
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key") return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key")
}, },
func(db *gorm.DB) error {
return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique")
},
} }
} }

View File

@@ -165,19 +165,6 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
} }
// GetClusterSupportsCrowdSec mocks base method.
func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec.
func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
}
// Close mocks base method. // Close mocks base method.
func (m *MockStore) Close(ctx context.Context) error { func (m *MockStore) Close(ctx context.Context) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -236,6 +223,21 @@ func (mr *MockStoreMockRecorder) CountEphemeralServicesByPeer(ctx, lockStrength,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID)
} }
// CountProxiesByAccountID mocks base method.
func (m *MockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountProxiesByAccountID", ctx, accountID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountProxiesByAccountID indicates an expected call of CountProxiesByAccountID.
func (mr *MockStoreMockRecorder) CountProxiesByAccountID(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountProxiesByAccountID", reflect.TypeOf((*MockStore)(nil).CountProxiesByAccountID), ctx, accountID)
}
// CreateAccessLog mocks base method. // CreateAccessLog mocks base method.
func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error { func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -574,6 +576,20 @@ func (mr *MockStoreMockRecorder) DeletePostureChecks(ctx, accountID, postureChec
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePostureChecks", reflect.TypeOf((*MockStore)(nil).DeletePostureChecks), ctx, accountID, postureChecksID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePostureChecks", reflect.TypeOf((*MockStore)(nil).DeletePostureChecks), ctx, accountID, postureChecksID)
} }
// DeleteProxy mocks base method.
func (m *MockStore) DeleteProxy(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteProxy", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteProxy indicates an expected call of DeleteProxy.
func (mr *MockStoreMockRecorder) DeleteProxy(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteProxy", reflect.TypeOf((*MockStore)(nil).DeleteProxy), ctx, proxyID)
}
// DeleteRoute mocks base method. // DeleteRoute mocks base method.
func (m *MockStore) DeleteRoute(ctx context.Context, accountID, routeID string) error { func (m *MockStore) DeleteRoute(ctx context.Context, accountID, routeID string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -714,6 +730,20 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID)
} }
// DisconnectProxy mocks base method.
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// DisconnectProxy indicates an expected call of DisconnectProxy.
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID)
}
// EphemeralServiceExists mocks base method. // EphemeralServiceExists mocks base method.
func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) { func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -1300,6 +1330,21 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
} }
// GetActiveProxyClusterAddressesForAccount mocks base method.
func (m *MockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveProxyClusterAddressesForAccount", ctx, accountID)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveProxyClusterAddressesForAccount indicates an expected call of GetActiveProxyClusterAddressesForAccount.
func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddressesForAccount", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddressesForAccount), ctx, accountID)
}
// GetActiveProxyClusters mocks base method. // GetActiveProxyClusters mocks base method.
func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) { func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -1388,6 +1433,20 @@ func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
} }
// GetClusterSupportsCrowdSec mocks base method.
func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec.
func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
}
// GetClusterSupportsCustomPorts mocks base method. // GetClusterSupportsCustomPorts mocks base method.
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -1957,6 +2016,51 @@ func (mr *MockStoreMockRecorder) GetProxyAccessTokenByHashedToken(ctx, lockStren
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByHashedToken", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByHashedToken), ctx, lockStrength, hashedToken) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByHashedToken", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByHashedToken), ctx, lockStrength, hashedToken)
} }
// GetProxyAccessTokenByID mocks base method.
func (m *MockStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types2.ProxyAccessToken, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxyAccessTokenByID", ctx, lockStrength, tokenID)
ret0, _ := ret[0].(*types2.ProxyAccessToken)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetProxyAccessTokenByID indicates an expected call of GetProxyAccessTokenByID.
func (mr *MockStoreMockRecorder) GetProxyAccessTokenByID(ctx, lockStrength, tokenID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByID), ctx, lockStrength, tokenID)
}
// GetProxyAccessTokensByAccountID mocks base method.
func (m *MockStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types2.ProxyAccessToken, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxyAccessTokensByAccountID", ctx, lockStrength, accountID)
ret0, _ := ret[0].([]*types2.ProxyAccessToken)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetProxyAccessTokensByAccountID indicates an expected call of GetProxyAccessTokensByAccountID.
func (mr *MockStoreMockRecorder) GetProxyAccessTokensByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokensByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokensByAccountID), ctx, lockStrength, accountID)
}
// GetProxyByAccountID mocks base method.
func (m *MockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxyByAccountID", ctx, accountID)
ret0, _ := ret[0].(*proxy.Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetProxyByAccountID indicates an expected call of GetProxyByAccountID.
func (mr *MockStoreMockRecorder) GetProxyByAccountID(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyByAccountID), ctx, accountID)
}
// GetResourceGroups mocks base method. // GetResourceGroups mocks base method.
func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) { func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -2389,6 +2493,21 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID)
} }
// IsClusterAddressConflicting mocks base method.
func (m *MockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsClusterAddressConflicting", ctx, clusterAddress, accountID)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsClusterAddressConflicting indicates an expected call of IsClusterAddressConflicting.
func (mr *MockStoreMockRecorder) IsClusterAddressConflicting(ctx, clusterAddress, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressConflicting", reflect.TypeOf((*MockStore)(nil).IsClusterAddressConflicting), ctx, clusterAddress, accountID)
}
// IsPrimaryAccount mocks base method. // IsPrimaryAccount mocks base method.
func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) { func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -2405,6 +2524,21 @@ func (mr *MockStoreMockRecorder) IsPrimaryAccount(ctx, accountID interface{}) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPrimaryAccount", reflect.TypeOf((*MockStore)(nil).IsPrimaryAccount), ctx, accountID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPrimaryAccount", reflect.TypeOf((*MockStore)(nil).IsPrimaryAccount), ctx, accountID)
} }
// IsProxyAccessTokenValid mocks base method.
func (m *MockStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsProxyAccessTokenValid", ctx, tokenID)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsProxyAccessTokenValid indicates an expected call of IsProxyAccessTokenValid.
func (mr *MockStoreMockRecorder) IsProxyAccessTokenValid(ctx, tokenID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsProxyAccessTokenValid", reflect.TypeOf((*MockStore)(nil).IsProxyAccessTokenValid), ctx, tokenID)
}
// ListCustomDomains mocks base method. // ListCustomDomains mocks base method.
func (m *MockStore) ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error) { func (m *MockStore) ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -8,7 +8,6 @@ import (
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
@@ -27,7 +26,6 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
@@ -110,16 +108,9 @@ type Account struct {
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"` Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
NetworkMapCache *NetworkMapBuilder `gorm:"-"`
nmapInitOnce *sync.Once `gorm:"-"`
ReverseProxyFreeDomainNonce string ReverseProxyFreeDomainNonce string
} }
func (a *Account) InitOnce() {
a.nmapInitOnce = &sync.Once{}
}
// this class is used by gorm only // this class is used by gorm only
type PrimaryAccountInfo struct { type PrimaryAccountInfo struct {
IsDomainPrimaryAccount bool IsDomainPrimaryAccount bool
@@ -155,108 +146,6 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
o.SignupFormPending == onboarding.SignupFormPending o.SignupFormPending == onboarding.SignupFormPending
} }
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
// from the ACL peers that have distribution groups associated with the peer ID.
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route {
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)
peerRoutesMembership := make(LookupMap)
for _, r := range append(routes, peerDisabledRoutes...) {
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
}
for _, peer := range aclPeers {
activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID)
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, peerGroups)
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
routes = append(routes, filteredRoutes...)
}
return routes
}
// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
_, found := peerMemberships[string(r.GetHAUniqueID())]
if !found {
filteredRoutes = append(filteredRoutes, r)
}
}
return filteredRoutes
}
// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map
func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
for _, groupID := range r.Groups {
_, found := groupListMap[groupID]
if found {
filteredRoutes = append(filteredRoutes, r)
break
}
}
}
return filteredRoutes
}
// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
// If the given is not a routing peer, then the lists are empty.
func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
peer := a.GetPeer(peerID)
if peer == nil {
log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id)
return enabledRoutes, disabledRoutes
}
seenRoute := make(map[route.ID]struct{})
takeRoute := func(r *route.Route, id string) {
if _, ok := seenRoute[r.ID]; ok {
return
}
seenRoute[r.ID] = struct{}{}
if r.Enabled {
r.Peer = peer.Key
enabledRoutes = append(enabledRoutes, r)
return
}
disabledRoutes = append(disabledRoutes, r)
}
for _, r := range a.Routes {
for _, groupID := range r.PeerGroups {
group := a.GetGroup(groupID)
if group == nil {
log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id)
continue
}
for _, id := range group.Peers {
if id != peerID {
continue
}
newPeerRoute := r.Copy()
newPeerRoute.Peer = id
newPeerRoute.PeerGroups = nil
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map
takeRoute(newPeerRoute, id)
break
}
}
if r.Peer == peerID {
takeRoute(r.Copy(), peerID)
}
}
return enabledRoutes, disabledRoutes
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix // GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route { func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route {
var routes []*route.Route var routes []*route.Route
@@ -276,106 +165,6 @@ func (a *Account) GetGroup(groupID string) *Group {
return a.Groups[groupID] return a.Groups[groupID]
} }
// GetPeerNetworkMap returns the networkmap for the given peer ID.
func (a *Account) GetPeerNetworkMap(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
validatedPeersMap map[string]struct{},
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
metrics *telemetry.AccountManagerMetrics,
groupIDToUserIDs map[string][]string,
) *NetworkMap {
start := time.Now()
peer := a.Peers[peerID]
if peer == nil {
return &NetworkMap{
Network: a.Network.Copy(),
}
}
if _, ok := validatedPeersMap[peerID]; !ok {
return &NetworkMap{
Network: a.Network.Copy(),
}
}
peerGroups := a.GetPeerGroups(peerID)
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
// exclude expired peers
var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer
for _, p := range aclPeers {
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
if a.Settings.PeerLoginExpirationEnabled && expired {
expiredPeers = append(expiredPeers, p)
continue
}
peersToConnect = append(peersToConnect, p)
}
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect, peerGroups)
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
var networkResourcesFirewallRules []*RouteFirewallRule
if isRouter {
networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies)
}
peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers)
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
dnsUpdate := nbdns.Config{
ServiceEnable: dnsManagementStatus,
}
if dnsManagementStatus {
var zones []nbdns.CustomZone
if peersCustomZone.Domain != "" {
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
zones = append(zones, nbdns.CustomZone{
Domain: peersCustomZone.Domain,
Records: records,
})
}
filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
zones = append(zones, filteredAccountZones...)
dnsUpdate.CustomZones = zones
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
}
nm := &NetworkMap{
Peers: peersToConnectIncludingRouters,
Network: a.Network.Copy(),
Routes: slices.Concat(networkResourcesRoutes, routesUpdate),
DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
AuthorizedUsers: authorizedUsers,
EnableSSH: enableSSH,
}
if metrics != nil {
objectCount := int64(len(peersToConnectIncludingRouters) + len(expiredPeers) + len(routesUpdate) + len(networkResourcesRoutes) + len(firewallRules) + +len(networkResourcesFirewallRules) + len(routesFirewallRules))
metrics.CountNetworkMapObjects(objectCount)
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
if objectCount > 5000 {
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+
"peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d, network resources routes: %d, network resources firewall rules: %d, routes firewall rules: %d",
a.Id, objectCount, len(peersToConnectIncludingRouters), len(expiredPeers), len(routesUpdate), len(firewallRules), len(networkResourcesRoutes), len(networkResourcesFirewallRules), len(routesFirewallRules))
}
}
return nm
}
func (a *Account) addNetworksRoutingPeers( func (a *Account) addNetworksRoutingPeers(
networkResourcesRoutes []*route.Route, networkResourcesRoutes []*route.Route,
peer *nbpeer.Peer, peer *nbpeer.Peer,
@@ -421,39 +210,6 @@ func (a *Account) addNetworksRoutingPeers(
return peersToConnect return peersToConnect
} }
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
groupList := account.GetPeerGroups(peerID)
var peerNSGroups []*nbdns.NameServerGroup
for _, nsGroup := range account.NameServerGroups {
if !nsGroup.Enabled {
continue
}
for _, gID := range nsGroup.Groups {
_, found := groupList[gID]
if found {
if !peerIsNameserver(account.GetPeer(peerID), nsGroup) {
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
break
}
}
}
}
return peerNSGroups
}
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup
func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
for _, ns := range nsGroup.NameServers {
if peer.IP.Equal(ns.IP.AsSlice()) {
return true
}
}
return false
}
func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) { func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) {
for _, peer := range account.Peers { for _, peer := range account.Peers {
label, err := GetPeerHostLabel(peer.Name, peerLabels) label, err := GetPeerHostLabel(peer.Name, peerLabels)
@@ -800,19 +556,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
return grps return grps
} }
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := a.GetPeerGroups(peerID)
enabled := true
for _, groupID := range a.DNSSettings.DisabledManagementGroups {
_, found := peerGroups[groupID]
if found {
enabled = false
break
}
}
return enabled
}
func (a *Account) GetPeerGroups(peerID string) LookupMap { func (a *Account) GetPeerGroups(peerID string) LookupMap {
groupList := make(LookupMap) groupList := make(LookupMap)
for groupID, group := range a.Groups { for groupID, group := range a.Groups {
@@ -941,8 +684,6 @@ func (a *Account) Copy() *Account {
NetworkResources: networkResources, NetworkResources: networkResources,
Services: services, Services: services,
Onboarding: a.Onboarding, Onboarding: a.Onboarding,
NetworkMapCache: a.NetworkMapCache,
nmapInitOnce: a.nmapInitOnce,
Domains: domains, Domains: domains,
} }
} }
@@ -1304,31 +1045,6 @@ func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks {
return nil return nil
} }
// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))
enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
for _, route := range enabledRoutes {
// If no access control groups are specified, accept all traffic.
if len(route.AccessControlGroups) == 0 {
defaultPermit := getDefaultPermit(route)
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
continue
}
distributionPeers := a.getDistributionGroupsPeers(route)
for _, accessGroup := range route.AccessControlGroups {
policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup})
rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers)
routesFirewallRules = append(routesFirewallRules, rules...)
}
}
return routesFirewallRules
}
func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule {
var fwRules []*RouteFirewallRule var fwRules []*RouteFirewallRule
for _, policy := range policies { for _, policy := range policies {
@@ -1387,50 +1103,6 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID
return distributionGroupPeers return distributionGroupPeers
} }
func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} {
distPeers := make(map[string]struct{})
for _, id := range route.Groups {
group := a.Groups[id]
if group == nil {
continue
}
for _, pID := range group.Peers {
distPeers[pID] = struct{}{}
}
}
return distPeers
}
func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
var rules []*RouteFirewallRule
sources := []string{"0.0.0.0/0"}
if route.Network.Addr().Is6() {
sources = []string{"::/0"}
}
rule := RouteFirewallRule{
SourceRanges: sources,
Action: string(PolicyTrafficActionAccept),
Destination: route.Network.String(),
Protocol: string(PolicyRuleProtocolALL),
Domains: route.Domains,
IsDynamic: route.IsDynamic(),
RouteID: route.ID,
}
rules = append(rules, &rule)
// dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally
if route.IsDynamic() {
ruleV6 := rule
ruleV6.SourceRanges = []string{"::/0"}
rules = append(rules, &ruleV6)
}
return rules
}
// GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups // GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
// and returns a list of policies that have rules with destinations matching the specified groups. // and returns a list of policies that have rules with destinations matching the specified groups.
func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
@@ -1508,65 +1180,6 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
return resourcePolicies return resourcePolicies
} }
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
var isRoutingPeer bool
var routes []*route.Route
allSourcePeers := make(map[string]struct{}, len(a.Peers))
for _, resource := range a.NetworkResources {
if !resource.Enabled {
continue
}
var addSourcePeers bool
networkRoutingPeers, exists := routers[resource.NetworkID]
if exists {
if router, ok := networkRoutingPeers[peerID]; ok {
isRoutingPeer, addSourcePeers = true, true
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, router, resourcePolicies)...)
}
}
addedResourceRoute := false
for _, policy := range resourcePolicies[resource.ID] {
var peers []string
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
peers = []string{policy.Rules[0].SourceResource.ID}
} else {
peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
}
if addSourcePeers {
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
allSourcePeers[pID] = struct{}{}
}
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
// add routes for the resource if the peer is in the distribution group
for peerId, router := range networkRoutingPeers {
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
}
addedResourceRoute = true
}
if addedResourceRoute {
break
}
}
}
return isRoutingPeer, routes, allSourcePeers
}
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
var dest []string
for _, peerID := range inputPeers {
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
dest = append(dest, peerID)
}
}
return dest
}
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
for _, groupID := range groups { for _, groupID := range groups {
@@ -1658,22 +1271,6 @@ func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string {
return result return result
} }
// getNetworkResourcesRoutes convert the network resources list to routes list.
func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route {
resourceAppliedPolicies := resourcePolicies[resource.ID]
var routes []*route.Route
// distribute the resource routes only if there is policy applied to it
if len(resourceAppliedPolicies) > 0 {
peer := a.GetPeer(peerId)
if peer != nil {
routes = append(routes, resource.ToRoute(peer, router))
}
}
return routes
}
func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter { func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter {
routers := make(map[string]map[string]*routerTypes.NetworkRouter) routers := make(map[string]map[string]*routerTypes.NetworkRouter)

View File

@@ -4,8 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"slices"
"testing" "testing"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -19,7 +17,6 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -451,402 +448,6 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) {
require.Len(t, result, 0) require.Len(t, result, 0)
} }
const (
accID = "accountID"
network1ID = "network1ID"
group1ID = "group1"
accNetResourcePeer1ID = "peer1"
accNetResourcePeer2ID = "peer2"
accNetResourceRouter1ID = "router1"
accNetResource1ID = "resource1ID"
accNetResourceRestrictPostureCheckID = "restrictPostureCheck"
accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck"
accNetResourceLockedPostureCheckID = "lockedPostureCheck"
accNetResourceLinuxPostureCheckID = "linuxPostureCheck"
)
var (
accNetResourcePeer1IP = net.IP{192, 168, 1, 1}
accNetResourcePeer2IP = net.IP{192, 168, 1, 2}
accNetResourceRouter1IP = net.IP{192, 168, 1, 3}
accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}}
)
func getBasicAccountsWithResource() *Account {
return &Account{
Id: accID,
Peers: map[string]*nbpeer.Peer{
accNetResourcePeer1ID: {
ID: accNetResourcePeer1ID,
AccountID: accID,
Key: "peer1Key",
IP: accNetResourcePeer1IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
WtVersion: "0.35.1",
KernelVersion: "4.4.0",
},
},
accNetResourcePeer2ID: {
ID: accNetResourcePeer2ID,
AccountID: accID,
Key: "peer2Key",
IP: accNetResourcePeer2IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "windows",
WtVersion: "0.34.1",
KernelVersion: "4.4.0",
},
},
accNetResourceRouter1ID: {
ID: accNetResourceRouter1ID,
AccountID: accID,
Key: "router1Key",
IP: accNetResourceRouter1IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
WtVersion: "0.35.1",
KernelVersion: "4.4.0",
},
},
},
Groups: map[string]*Group{
group1ID: {
ID: group1ID,
Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID},
},
},
Networks: []*networkTypes.Network{
{
ID: network1ID,
AccountID: accID,
Name: "network1",
},
},
NetworkRouters: []*routerTypes.NetworkRouter{
{
ID: accNetResourceRouter1ID,
NetworkID: network1ID,
AccountID: accID,
Peer: accNetResourceRouter1ID,
PeerGroups: []string{},
Masquerade: false,
Metric: 100,
Enabled: true,
},
},
NetworkResources: []*resourceTypes.NetworkResource{
{
ID: accNetResource1ID,
AccountID: accID,
NetworkID: network1ID,
Address: "10.10.10.0/24",
Prefix: netip.MustParsePrefix("10.10.10.0/24"),
Type: resourceTypes.NetworkResourceType("subnet"),
Enabled: true,
},
},
Policies: []*Policy{
{
ID: "policy1ID",
AccountID: accID,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "rule1ID",
Enabled: true,
Sources: []string{group1ID},
DestinationResource: Resource{
ID: accNetResource1ID,
Type: "Host",
},
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"80"},
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: nil,
},
},
PostureChecks: []*posture.Checks{
{
ID: accNetResourceRestrictPostureCheckID,
Name: accNetResourceRestrictPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.35.0",
},
},
},
{
ID: accNetResourceRelaxedPostureCheckID,
Name: accNetResourceRelaxedPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1",
},
},
},
{
ID: accNetResourceLockedPostureCheckID,
Name: accNetResourceLockedPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "7.7.7",
},
},
},
{
ID: accNetResourceLinuxPostureCheckID,
Name: accNetResourceLinuxPostureCheckID,
Checks: posture.ChecksDefinition{
OSVersionCheck: &posture.OSVersionCheck{
Linux: &posture.MinKernelVersionCheck{
MinKernelVersion: "0.0.0"},
},
},
},
},
}
}
func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// all peers should match the policy
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should allow peer1 to match the policy
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should not match any peer
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 0, "expected rules count don't match")
}
func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should allow peer1 to match the policy
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
// should allow peer1 and peer2 to match the policy
newPolicy := &Policy{
ID: "policy2ID",
AccountID: accID,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "policy2ID",
Enabled: true,
Sources: []string{group1ID},
DestinationResource: Resource{
ID: accNetResource1ID,
Type: "Host",
},
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"22"},
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID},
}
account.Policies = append(account.Policies, newPolicy)
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 2, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
assert.Equal(t, uint16(22), rules[1].Port, "should have port 22")
assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp")
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String())
}
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// two posture checks should match only the peers that match both checks
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) {
account := getBasicAccountsWithResource()
account.Peers["router2Id"] = &nbpeer.Peer{Key: "router2Key", ID: "router2Id", AccountID: accID, IP: net.IP{192, 168, 1, 4}}
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
ID: "router2Id",
NetworkID: network1ID,
AccountID: accID,
Peer: "router2Id",
})
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
}
func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) { func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -1,47 +0,0 @@
package types
import (
"context"
"sync"
)
type Holder struct {
mu sync.RWMutex
accounts map[string]*Account
}
func NewHolder() *Holder {
return &Holder{
accounts: make(map[string]*Account),
}
}
func (h *Holder) GetAccount(id string) *Account {
h.mu.RLock()
defer h.mu.RUnlock()
return h.accounts[id]
}
func (h *Holder) AddAccount(account *Account) {
h.mu.Lock()
defer h.mu.Unlock()
a := h.accounts[account.Id]
if a != nil && a.Network.CurrentSerial() >= account.Network.CurrentSerial() {
return
}
h.accounts[account.Id] = account
}
func (h *Holder) LoadOrStoreFunc(ctx context.Context, id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
h.mu.Lock()
defer h.mu.Unlock()
if acc, ok := h.accounts[id]; ok {
return acc, nil
}
account, err := accGetter(ctx, id)
if err != nil {
return nil, err
}
h.accounts[id] = account
return account, nil
}

View File

@@ -1,67 +0,0 @@
package types
import (
"context"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
)
func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) {
if a.NetworkMapCache != nil {
return
}
a.nmapInitOnce.Do(func() {
a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers)
})
}
func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) {
a.initNetworkMapBuilder(validatedPeers)
}
func (a *Account) GetPeerNetworkMapExp(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
validatedPeers map[string]struct{},
metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
a.initNetworkMapBuilder(validatedPeers)
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, accountZones, validatedPeers, metrics)
}
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
if a.NetworkMapCache == nil {
return nil
}
return a.NetworkMapCache.OnPeerAddedIncremental(a, peerId)
}
func (a *Account) OnPeersAddedUpdNetworkMapCache(peerIds ...string) {
if a.NetworkMapCache == nil {
return
}
a.NetworkMapCache.EnqueuePeersForIncrementalAdd(a, peerIds...)
}
func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error {
if a.NetworkMapCache == nil {
return nil
}
return a.NetworkMapCache.OnPeerDeleted(a, peerId)
}
func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) {
if a.NetworkMapCache == nil {
return
}
a.NetworkMapCache.UpdatePeer(peer)
}
func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) {
a.initNetworkMapBuilder(validatedPeers)
}

View File

@@ -1,592 +0,0 @@
package types
import (
"context"
"encoding/json"
"fmt"
"net"
"net/netip"
"os"
"path/filepath"
"sort"
"testing"
"time"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route"
)
func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid == offlinePeerID {
continue
}
validatedPeersMap[pid] = struct{}{}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
legacyNetworkMap := account.GetPeerNetworkMap(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
if components == nil {
t.Fatal("GetPeerNetworkMapComponents returned nil")
}
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
if newNetworkMap == nil {
t.Fatal("CalculateNetworkMapFromComponents returned nil")
}
compareNetworkMaps(t, legacyNetworkMap, newNetworkMap)
}
func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid == offlinePeerID {
continue
}
validatedPeersMap[pid] = struct{}{}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
legacyNetworkMap := account.GetPeerNetworkMap(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil")
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil")
normalizeAndSortNetworkMap(legacyNetworkMap)
normalizeAndSortNetworkMap(newNetworkMap)
componentsJSON, err := json.MarshalIndent(components, "", " ")
require.NoError(t, err, "error marshaling components to JSON")
legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
newJSON, err := json.MarshalIndent(newNetworkMap, "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
goldenDir := filepath.Join("testdata", "comparison")
err = os.MkdirAll(goldenDir, 0755)
require.NoError(t, err)
legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json")
err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644)
require.NoError(t, err, "error writing legacy golden file")
newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json")
err = os.WriteFile(newGoldenPath, newJSON, 0644)
require.NoError(t, err, "error writing components golden file")
componentsPath := filepath.Join(goldenDir, "components.json")
err = os.WriteFile(componentsPath, componentsJSON, 0644)
require.NoError(t, err, "error writing components golden file")
require.JSONEq(t, string(legacyJSON), string(newJSON),
"NetworkMaps from legacy and components approaches do not match.\n"+
"Legacy JSON saved to: %s\n"+
"Components JSON saved to: %s",
legacyGoldenPath, newGoldenPath)
t.Logf("✅ NetworkMaps are identical")
t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath)
t.Logf(" Components NetworkMap: %s", newGoldenPath)
}
func normalizeAndSortNetworkMap(nm *NetworkMap) {
if nm == nil {
return
}
sort.Slice(nm.Peers, func(i, j int) bool {
return nm.Peers[i].ID < nm.Peers[j].ID
})
sort.Slice(nm.OfflinePeers, func(i, j int) bool {
return nm.OfflinePeers[i].ID < nm.OfflinePeers[j].ID
})
sort.Slice(nm.Routes, func(i, j int) bool {
return string(nm.Routes[i].ID) < string(nm.Routes[j].ID)
})
sort.Slice(nm.FirewallRules, func(i, j int) bool {
if nm.FirewallRules[i].PeerIP != nm.FirewallRules[j].PeerIP {
return nm.FirewallRules[i].PeerIP < nm.FirewallRules[j].PeerIP
}
if nm.FirewallRules[i].Direction != nm.FirewallRules[j].Direction {
return nm.FirewallRules[i].Direction < nm.FirewallRules[j].Direction
}
if nm.FirewallRules[i].Protocol != nm.FirewallRules[j].Protocol {
return nm.FirewallRules[i].Protocol < nm.FirewallRules[j].Protocol
}
if nm.FirewallRules[i].Port != nm.FirewallRules[j].Port {
return nm.FirewallRules[i].Port < nm.FirewallRules[j].Port
}
return nm.FirewallRules[i].PolicyID < nm.FirewallRules[j].PolicyID
})
for i := range nm.RoutesFirewallRules {
sort.Strings(nm.RoutesFirewallRules[i].SourceRanges)
}
sort.Slice(nm.RoutesFirewallRules, func(i, j int) bool {
if nm.RoutesFirewallRules[i].Destination != nm.RoutesFirewallRules[j].Destination {
return nm.RoutesFirewallRules[i].Destination < nm.RoutesFirewallRules[j].Destination
}
minLen := len(nm.RoutesFirewallRules[i].SourceRanges)
if len(nm.RoutesFirewallRules[j].SourceRanges) < minLen {
minLen = len(nm.RoutesFirewallRules[j].SourceRanges)
}
for k := 0; k < minLen; k++ {
if nm.RoutesFirewallRules[i].SourceRanges[k] != nm.RoutesFirewallRules[j].SourceRanges[k] {
return nm.RoutesFirewallRules[i].SourceRanges[k] < nm.RoutesFirewallRules[j].SourceRanges[k]
}
}
if len(nm.RoutesFirewallRules[i].SourceRanges) != len(nm.RoutesFirewallRules[j].SourceRanges) {
return len(nm.RoutesFirewallRules[i].SourceRanges) < len(nm.RoutesFirewallRules[j].SourceRanges)
}
if string(nm.RoutesFirewallRules[i].RouteID) != string(nm.RoutesFirewallRules[j].RouteID) {
return string(nm.RoutesFirewallRules[i].RouteID) < string(nm.RoutesFirewallRules[j].RouteID)
}
if nm.RoutesFirewallRules[i].PolicyID != nm.RoutesFirewallRules[j].PolicyID {
return nm.RoutesFirewallRules[i].PolicyID < nm.RoutesFirewallRules[j].PolicyID
}
if nm.RoutesFirewallRules[i].Port != nm.RoutesFirewallRules[j].Port {
return nm.RoutesFirewallRules[i].Port < nm.RoutesFirewallRules[j].Port
}
return nm.RoutesFirewallRules[i].Protocol < nm.RoutesFirewallRules[j].Protocol
})
if nm.DNSConfig.CustomZones != nil {
for i := range nm.DNSConfig.CustomZones {
sort.Slice(nm.DNSConfig.CustomZones[i].Records, func(a, b int) bool {
return nm.DNSConfig.CustomZones[i].Records[a].Name < nm.DNSConfig.CustomZones[i].Records[b].Name
})
}
}
if len(nm.DNSConfig.NameServerGroups) != 0 {
sort.Slice(nm.DNSConfig.NameServerGroups, func(a, b int) bool {
return nm.DNSConfig.NameServerGroups[a].Name < nm.DNSConfig.NameServerGroups[b].Name
})
}
}
func compareNetworkMaps(t *testing.T, legacy, current *NetworkMap) {
t.Helper()
if legacy.Network.Serial != current.Network.Serial {
t.Errorf("Network Serial mismatch: legacy=%d, current=%d", legacy.Network.Serial, current.Network.Serial)
}
if len(legacy.Peers) != len(current.Peers) {
t.Errorf("Peers count mismatch: legacy=%d, current=%d", len(legacy.Peers), len(current.Peers))
}
legacyPeerIDs := make(map[string]bool)
for _, p := range legacy.Peers {
legacyPeerIDs[p.ID] = true
}
for _, p := range current.Peers {
if !legacyPeerIDs[p.ID] {
t.Errorf("Current NetworkMap contains peer %s not in legacy", p.ID)
}
}
if len(legacy.OfflinePeers) != len(current.OfflinePeers) {
t.Errorf("OfflinePeers count mismatch: legacy=%d, current=%d", len(legacy.OfflinePeers), len(current.OfflinePeers))
}
if len(legacy.FirewallRules) != len(current.FirewallRules) {
t.Logf("FirewallRules count mismatch: legacy=%d, current=%d", len(legacy.FirewallRules), len(current.FirewallRules))
}
if len(legacy.Routes) != len(current.Routes) {
t.Logf("Routes count mismatch: legacy=%d, current=%d", len(legacy.Routes), len(current.Routes))
}
if len(legacy.RoutesFirewallRules) != len(current.RoutesFirewallRules) {
t.Logf("RoutesFirewallRules count mismatch: legacy=%d, current=%d", len(legacy.RoutesFirewallRules), len(current.RoutesFirewallRules))
}
if legacy.DNSConfig.ServiceEnable != current.DNSConfig.ServiceEnable {
t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, current=%v", legacy.DNSConfig.ServiceEnable, current.DNSConfig.ServiceEnable)
}
}
const (
numPeers = 100
devGroupID = "group-dev"
opsGroupID = "group-ops"
allGroupID = "group-all"
routeID = route.ID("route-main")
routeHA1ID = route.ID("route-ha-1")
routeHA2ID = route.ID("route-ha-2")
policyIDDevOps = "policy-dev-ops"
policyIDAll = "policy-all"
policyIDPosture = "policy-posture"
policyIDDrop = "policy-drop"
postureCheckID = "posture-check-ver"
networkResourceID = "res-database"
networkID = "net-database"
networkRouterID = "router-database"
nameserverGroupID = "ns-group-main"
testingPeerID = "peer-60"
expiredPeerID = "peer-98"
offlinePeerID = "peer-99"
routingPeerID = "peer-95"
testAccountID = "account-comparison-test"
)
func createTestAccount() *Account {
peers := make(map[string]*nbpeer.Peer)
devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{}
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
ip := net.IP{100, 64, 0, byte(i + 1)}
wtVersion := "0.25.0"
if i%2 == 0 {
wtVersion = "0.40.0"
}
p := &nbpeer.Peer{
ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1),
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"},
}
if peerID == expiredPeerID {
p.LoginExpirationEnabled = true
pastTimestamp := time.Now().Add(-2 * time.Hour)
p.LastLogin = &pastTimestamp
}
peers[peerID] = p
allGroupPeers = append(allGroupPeers, peerID)
if i < numPeers/2 {
devGroupPeers = append(devGroupPeers, peerID)
} else {
opsGroupPeers = append(opsGroupPeers, peerID)
}
}
groups := map[string]*Group{
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
}
policies := []*Policy{
{
ID: policyIDAll, Name: "Default-Allow", Enabled: true,
Rules: []*PolicyRule{{
ID: policyIDAll, Name: "Allow All", Enabled: true, Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
Sources: []string{allGroupID}, Destinations: []string{allGroupID},
}},
},
{
ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true,
Rules: []*PolicyRule{{
ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolTCP, Bidirectional: false,
PortRanges: []RulePortRange{{Start: 8080, End: 8090}},
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
}},
},
{
ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true,
Rules: []*PolicyRule{{
ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: PolicyTrafficActionDrop,
Protocol: PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true,
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
}},
},
{
ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true,
SourcePostureChecks: []string{postureCheckID},
Rules: []*PolicyRule{{
ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
Sources: []string{opsGroupID}, DestinationResource: Resource{ID: networkResourceID},
}},
},
}
routes := map[route.ID]*route.Route{
routeID: {
ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"),
Peer: peers["peer-75"].Key,
PeerID: "peer-75",
Description: "Route to internal resource", Enabled: true,
PeerGroups: []string{devGroupID, opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{devGroupID},
},
routeHA1ID: {
ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
Peer: peers["peer-80"].Key,
PeerID: "peer-80",
Description: "HA Route 1", Enabled: true, Metric: 1000,
PeerGroups: []string{allGroupID},
Groups: []string{allGroupID},
AccessControlGroups: []string{allGroupID},
},
routeHA2ID: {
ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
Peer: peers["peer-90"].Key,
PeerID: "peer-90",
Description: "HA Route 2", Enabled: true, Metric: 900,
PeerGroups: []string{devGroupID, opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{allGroupID},
},
}
account := &Account{
Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes,
Network: &Network{
Identifier: "net-comparison-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
},
DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}},
NameServerGroups: map[string]*nbdns.NameServerGroup{
nameserverGroupID: {
ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID},
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}},
},
},
PostureChecks: []*posture.Checks{
{ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
}},
},
NetworkResources: []*resourceTypes.NetworkResource{
{ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"},
},
Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}},
NetworkRouters: []*routerTypes.NetworkRouter{
{ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID},
},
Settings: &Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
}
for _, p := range account.Policies {
p.AccountID = account.Id
}
for _, r := range account.Routes {
r.AccountID = account.Id
}
return account
}
func BenchmarkLegacyNetworkMap(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMap(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
}
}
func BenchmarkComponentsNetworkMap(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ResetTimer()
for i := 0; i < b.N; i++ {
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
_ = CalculateNetworkMapFromComponents(ctx, components)
}
}
func BenchmarkComponentsCreation(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
}
}
func BenchmarkCalculationFromComponents(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = CalculateNetworkMapFromComponents(ctx, components)
}
}

View File

@@ -19,8 +19,6 @@ import (
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
) )
const EnvNewNetworkMapCompacted = "NB_NETWORK_MAP_COMPACTED"
type NetworkMapComponents struct { type NetworkMapComponents struct {
PeerID string PeerID string

View File

@@ -0,0 +1,787 @@
package types_test
import (
"context"
"net"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
func networkMapFromComponents(t *testing.T, account *types.Account, peerID string, validatedPeers map[string]struct{}) *types.NetworkMap {
t.Helper()
return account.GetPeerNetworkMapFromComponents(
context.Background(),
peerID,
account.GetPeersCustomZone(context.Background(), "netbird.io"),
nil,
validatedPeers,
account.GetResourcePoliciesMap(),
account.GetResourceRoutersMap(),
nil,
account.GetActiveGroupUsers(),
)
}
func allPeersValidated(account *types.Account, excludePeerIDs ...string) map[string]struct{} {
excludeSet := make(map[string]struct{}, len(excludePeerIDs))
for _, id := range excludePeerIDs {
excludeSet[id] = struct{}{}
}
validated := make(map[string]struct{}, len(account.Peers))
for id := range account.Peers {
if _, excluded := excludeSet[id]; !excluded {
validated[id] = struct{}{}
}
}
return validated
}
func peerIDs(peers []*nbpeer.Peer) []string {
ids := make([]string, len(peers))
for i, p := range peers {
ids[i] = p.ID
}
return ids
}
func TestNetworkMapComponents_RegularPeerConnectivity(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.NotNil(t, nm)
assert.Contains(t, peerIDs(nm.Peers), "peer-dst-1", "should see peer from destination group via bidirectional policy")
assert.Contains(t, peerIDs(nm.Peers), "peer-router-1", "should see router peer via resource policy")
assert.NotContains(t, peerIDs(nm.Peers), "peer-src-1", "should not see itself")
assert.Empty(t, nm.OfflinePeers, "no expired peers expected")
}
func TestNetworkMapComponents_IntraGroupConnectivity(t *testing.T) {
account := createComponentTestAccount()
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-intra-src", Name: "Intra-source connectivity", Enabled: true, AccountID: account.Id,
Rules: []*types.PolicyRule{{
ID: "rule-intra-src", Name: "src <-> src", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Bidirectional: true,
Sources: []string{"group-src"}, Destinations: []string{"group-src"},
}},
})
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.Contains(t, peerIDs(nm.Peers), "peer-src-2", "should see peer from same group with intra-group policy")
}
func TestNetworkMapComponents_FirewallRules(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
require.NotEmpty(t, nm.FirewallRules, "firewall rules should be generated")
var hasAcceptAll bool
for _, rule := range nm.FirewallRules {
if rule.Protocol == string(types.PolicyRuleProtocolALL) && rule.Action == string(types.PolicyTrafficActionAccept) {
hasAcceptAll = true
}
}
assert.True(t, hasAcceptAll, "should have an accept-all firewall rule from the base policy")
}
func TestNetworkMapComponents_LoginExpiration(t *testing.T) {
account := createComponentTestAccount()
account.Settings.PeerLoginExpirationEnabled = true
account.Settings.PeerLoginExpiration = 1 * time.Hour
expiredTime := time.Now().Add(-2 * time.Hour)
account.Peers["peer-dst-1"].LoginExpirationEnabled = true
account.Peers["peer-dst-1"].LastLogin = &expiredTime
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.Contains(t, peerIDs(nm.OfflinePeers), "peer-dst-1", "expired peer should be in OfflinePeers")
assert.NotContains(t, peerIDs(nm.Peers), "peer-dst-1", "expired peer should NOT be in active Peers")
}
func TestNetworkMapComponents_InvalidatedPeerExcluded(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account, "peer-dst-1")
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.NotContains(t, peerIDs(nm.Peers), "peer-dst-1", "non-validated peer should be excluded")
assert.NotContains(t, peerIDs(nm.OfflinePeers), "peer-dst-1", "non-validated peer should not be in offline peers either")
}
func TestNetworkMapComponents_NonValidatedTargetPeer(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account, "peer-src-1")
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.Empty(t, nm.Peers, "non-validated target peer should get empty network map")
assert.Empty(t, nm.FirewallRules)
}
func TestNetworkMapComponents_NetworkResourceRoutes_SourcePeer(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var hasResourceRoute bool
for _, r := range nm.Routes {
if r.Network.String() == "10.200.0.1/32" {
hasResourceRoute = true
break
}
}
assert.True(t, hasResourceRoute, "source peer should receive route to network resource via router")
assert.Contains(t, peerIDs(nm.Peers), "peer-router-1", "source peer should see the routing peer")
}
func TestNetworkMapComponents_NetworkResourceRoutes_RouterPeer(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
var hasResourceRoute bool
for _, r := range nm.Routes {
if r.Network.String() == "10.200.0.1/32" {
hasResourceRoute = true
break
}
}
assert.True(t, hasResourceRoute, "router peer should receive network resource route")
assert.NotEmpty(t, nm.RoutesFirewallRules, "router peer should have route firewall rules for the resource")
}
func TestNetworkMapComponents_NetworkResourceRoutes_UnrelatedPeer(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-dst-1", validated)
for _, r := range nm.Routes {
assert.NotEqual(t, "10.200.0.1/32", r.Network.String(), "unrelated peer should not receive network resource route")
}
}
func TestNetworkMapComponents_NetworkResource_WithPostureCheck(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.PostureChecks = []*posture.Checks{
{ID: "pc-version", Name: "Version check", Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
}},
}
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-posture-resource", Name: "Posture resource access", Enabled: true, AccountID: account.Id,
SourcePostureChecks: []string{"pc-version"},
Rules: []*types.PolicyRule{{
ID: "rule-posture-resource", Name: "Posture -> Resource", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Sources: []string{"group-src"},
DestinationResource: types.Resource{ID: "resource-guarded"},
}},
})
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
ID: "resource-guarded", NetworkID: "net-guarded", AccountID: account.Id, Enabled: true,
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.1.1/32"), Address: "10.200.1.1/32",
})
account.Networks = append(account.Networks, &networkTypes.Network{
ID: "net-guarded", Name: "Guarded Net", AccountID: account.Id,
})
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
ID: "router-guarded", NetworkID: "net-guarded", Peer: "peer-router-1", Enabled: true, AccountID: account.Id,
})
t.Run("peer passes posture check", func(t *testing.T) {
account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0"
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var hasGuardedRoute bool
for _, r := range nm.Routes {
if r.Network.String() == "10.200.1.1/32" {
hasGuardedRoute = true
}
}
assert.True(t, hasGuardedRoute, "peer passing posture check should get guarded resource route")
})
t.Run("peer fails posture check", func(t *testing.T) {
account.Peers["peer-src-1"].Meta.WtVersion = "0.20.0"
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
for _, r := range nm.Routes {
assert.NotEqual(t, "10.200.1.1/32", r.Network.String(), "peer failing posture check should NOT get guarded resource route")
}
})
}
func TestNetworkMapComponents_NetworkResource_MultiplePostureChecks(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.PostureChecks = []*posture.Checks{
{ID: "pc-version", Name: "Version", Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
}},
{ID: "pc-os", Name: "OS check", Checks: posture.ChecksDefinition{
OSVersionCheck: &posture.OSVersionCheck{Linux: &posture.MinKernelVersionCheck{MinKernelVersion: "5.0"}},
}},
}
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-multi-posture", Name: "Multi posture", Enabled: true, AccountID: account.Id,
SourcePostureChecks: []string{"pc-version", "pc-os"},
Rules: []*types.PolicyRule{{
ID: "rule-multi-posture", Name: "Multi posture rule", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Sources: []string{"group-src"},
DestinationResource: types.Resource{ID: "resource-strict"},
}},
})
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
ID: "resource-strict", NetworkID: "net-strict", AccountID: account.Id, Enabled: true,
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.2.1/32"), Address: "10.200.2.1/32",
})
account.Networks = append(account.Networks, &networkTypes.Network{
ID: "net-strict", Name: "Strict Net", AccountID: account.Id,
})
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
ID: "router-strict", NetworkID: "net-strict", Peer: "peer-router-1", Enabled: true, AccountID: account.Id,
})
t.Run("passes both posture checks", func(t *testing.T) {
account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0"
account.Peers["peer-src-1"].Meta.GoOS = "linux"
account.Peers["peer-src-1"].Meta.KernelVersion = "6.1.0"
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var found bool
for _, r := range nm.Routes {
if r.Network.String() == "10.200.2.1/32" {
found = true
}
}
assert.True(t, found, "peer passing both checks should get resource route")
})
t.Run("fails version posture check", func(t *testing.T) {
account.Peers["peer-src-1"].Meta.WtVersion = "0.20.0"
account.Peers["peer-src-1"].Meta.KernelVersion = "6.1.0"
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
for _, r := range nm.Routes {
assert.NotEqual(t, "10.200.2.1/32", r.Network.String(), "peer failing version check should NOT get resource route")
}
})
t.Run("fails OS posture check", func(t *testing.T) {
account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0"
account.Peers["peer-src-1"].Meta.KernelVersion = "4.0.0"
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
for _, r := range nm.Routes {
assert.NotEqual(t, "10.200.2.1/32", r.Network.String(), "peer failing OS check should NOT get resource route")
}
})
}
func TestNetworkMapComponents_RouterPeerFirewallRules(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
var resourceFWRules []*types.RouteFirewallRule
for _, rule := range nm.RoutesFirewallRules {
if rule.Destination == "10.200.0.1/32" {
resourceFWRules = append(resourceFWRules, rule)
}
}
assert.NotEmpty(t, resourceFWRules, "router should have firewall rules for the network resource")
var hasSourcePeerIP bool
for _, rule := range resourceFWRules {
for _, sr := range rule.SourceRanges {
if sr == account.Peers["peer-src-1"].IP.String()+"/32" || sr == account.Peers["peer-src-2"].IP.String()+"/32" {
hasSourcePeerIP = true
}
}
}
assert.True(t, hasSourcePeerIP, "resource firewall rules should include source peer IPs")
}
func TestNetworkMapComponents_DNSManagement(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
t.Run("peer in DNS-enabled group", func(t *testing.T) {
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.True(t, nm.DNSConfig.ServiceEnable, "peer in non-disabled group should have DNS enabled")
})
t.Run("peer in DNS-disabled group", func(t *testing.T) {
nm := networkMapFromComponents(t, account, "peer-dst-1", validated)
assert.False(t, nm.DNSConfig.ServiceEnable, "peer in DNS-disabled group should have DNS disabled")
})
}
func TestNetworkMapComponents_NameServerGroups(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.True(t, nm.DNSConfig.ServiceEnable)
var hasNSGroup bool
for _, ns := range nm.DNSConfig.NameServerGroups {
if ns.ID == "ns-main" {
hasNSGroup = true
}
}
assert.True(t, hasNSGroup, "peer in NS group should receive nameserver configuration")
}
func TestNetworkMapComponents_RoutesWithHADeduplication(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.Routes["route-ha-1"] = &route.Route{
ID: "route-ha-1", Network: netip.MustParsePrefix("172.16.0.0/16"),
Peer: account.Peers["peer-dst-1"].Key, PeerID: "peer-dst-1",
Enabled: true, Metric: 100, AccountID: account.Id,
Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-dst"},
}
account.Routes["route-ha-2"] = &route.Route{
ID: "route-ha-2", Network: netip.MustParsePrefix("172.16.0.0/16"),
Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1",
Enabled: true, Metric: 200, AccountID: account.Id,
Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-src"},
}
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
haCount := 0
for _, r := range nm.Routes {
if r.Network.String() == "172.16.0.0/16" {
haCount++
}
}
assert.Equal(t, 1, haCount, "peer should only receive one route from HA group (not both, since it's a member of one)")
}
func TestNetworkMapComponents_RoutesFirewallRulesForAccessControl(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.Routes["route-acl"] = &route.Route{
ID: "route-acl", Network: netip.MustParsePrefix("192.168.100.0/24"),
Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1",
Enabled: true, Metric: 100, AccountID: account.Id,
Groups: []string{"group-dst"},
PeerGroups: []string{"group-src"},
AccessControlGroups: []string{"group-dst"},
}
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var hasFWRule bool
for _, rule := range nm.RoutesFirewallRules {
if rule.Destination == "192.168.100.0/24" {
hasFWRule = true
}
}
assert.True(t, hasFWRule, "routing peer should have firewall rules for route with access control groups")
}
func TestNetworkMapComponents_RoutesDefaultPermit(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.Routes["route-open"] = &route.Route{
ID: "route-open", Network: netip.MustParsePrefix("10.99.0.0/16"),
Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1",
Enabled: true, Metric: 100, AccountID: account.Id,
Groups: []string{"group-src"},
PeerGroups: []string{"group-src"},
}
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var hasFWRule bool
for _, rule := range nm.RoutesFirewallRules {
if rule.Destination == "10.99.0.0/16" {
hasFWRule = true
}
}
assert.True(t, hasFWRule, "route without access control groups should have default permit firewall rules")
}
func TestNetworkMapComponents_SSHAuthorizedUsers(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.Peers["peer-dst-1"].SSHEnabled = true
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: account.Id,
Rules: []*types.PolicyRule{{
ID: "rule-ssh", Name: "SSH to dst", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Bidirectional: true,
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
}},
})
nm := networkMapFromComponents(t, account, "peer-dst-1", validated)
assert.True(t, nm.EnableSSH, "SSH-enabled peer with matching policy should have EnableSSH")
}
func TestNetworkMapComponents_DisabledPolicyIgnored(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
for _, p := range account.Policies {
p.Enabled = false
}
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.Empty(t, nm.Peers, "with all policies disabled, peer should see no other peers")
assert.Empty(t, nm.FirewallRules)
}
func TestNetworkMapComponents_DisabledRouteIgnored(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
for _, r := range account.Routes {
r.Enabled = false
}
for _, r := range account.NetworkResources {
r.Enabled = false
}
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
assert.Empty(t, nm.Routes, "disabled routes should not appear in network map")
}
func TestNetworkMapComponents_DisabledNetworkResourceIgnored(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
for _, r := range account.NetworkResources {
r.Enabled = false
}
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
for _, r := range nm.Routes {
assert.NotEqual(t, "10.200.0.1/32", r.Network.String(), "disabled resource should not generate routes")
}
}
func TestNetworkMapComponents_BidirectionalPolicy(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nmSrc := networkMapFromComponents(t, account, "peer-src-1", validated)
nmDst := networkMapFromComponents(t, account, "peer-dst-1", validated)
assert.Contains(t, peerIDs(nmSrc.Peers), "peer-dst-1", "src should see dst via bidirectional policy")
assert.Contains(t, peerIDs(nmDst.Peers), "peer-src-1", "dst should see src via bidirectional policy")
}
func TestNetworkMapComponents_DropPolicy(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-drop", Name: "Drop traffic", Enabled: true, AccountID: account.Id,
Rules: []*types.PolicyRule{{
ID: "rule-drop", Name: "Drop src->dst", Enabled: true,
Action: types.PolicyTrafficActionDrop, Protocol: types.PolicyRuleProtocolTCP,
Ports: []string{"5432"},
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
}},
})
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var hasDropRule bool
for _, rule := range nm.FirewallRules {
if rule.Action == string(types.PolicyTrafficActionDrop) && rule.Port == "5432" {
hasDropRule = true
}
}
assert.True(t, hasDropRule, "drop policy should generate drop firewall rule")
}
func TestNetworkMapComponents_PortRangePolicy(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.Peers["peer-src-1"].Meta.WtVersion = "0.50.0"
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-range", Name: "Port range", Enabled: true, AccountID: account.Id,
Rules: []*types.PolicyRule{{
ID: "rule-range", Name: "Range rule", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP,
PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}},
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
}},
})
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var hasRangeRule bool
for _, rule := range nm.FirewallRules {
if rule.PortRange.Start == 8080 && rule.PortRange.End == 8090 {
hasRangeRule = true
}
}
assert.True(t, hasRangeRule, "port range policy should generate corresponding firewall rule")
}
func TestNetworkMapComponents_MultipleNetworkResources(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
ID: "resource-2", NetworkID: "net-1", AccountID: account.Id, Enabled: true,
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.0.2/32"), Address: "10.200.0.2/32",
})
account.Groups["group-res2"] = &types.Group{ID: "group-res2", Name: "Resource 2 Group", Peers: []string{"peer-src-1", "peer-src-2"},
Resources: []types.Resource{{ID: "resource-2"}},
}
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-res2", Name: "Resource 2 Policy", Enabled: true, AccountID: account.Id,
Rules: []*types.PolicyRule{{
ID: "rule-res2", Name: "Access Resource 2", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Sources: []string{"group-src"},
DestinationResource: types.Resource{ID: "resource-2"},
}},
})
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
resourceRouteCount := 0
for _, r := range nm.Routes {
if r.Network.String() == "10.200.0.1/32" || r.Network.String() == "10.200.0.2/32" {
resourceRouteCount++
}
}
assert.Equal(t, 2, resourceRouteCount, "router should have routes for both network resources")
}
func TestNetworkMapComponents_DomainNetworkResource(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
ID: "resource-domain", NetworkID: "net-1", AccountID: account.Id, Enabled: true,
Type: resourceTypes.Domain, Domain: "api.example.com", Address: "api.example.com",
})
account.Groups["group-res-domain"] = &types.Group{
ID: "group-res-domain", Name: "Domain Resource Group",
Resources: []types.Resource{{ID: "resource-domain"}},
}
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-domain", Name: "Domain resource policy", Enabled: true, AccountID: account.Id,
Rules: []*types.PolicyRule{{
ID: "rule-domain", Name: "Access domain resource", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Sources: []string{"group-src"},
DestinationResource: types.Resource{ID: "resource-domain"},
}},
})
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var hasDomainRoute bool
for _, r := range nm.Routes {
if r.NetworkType == route.DomainNetwork && len(r.Domains) > 0 && r.Domains[0].SafeString() == "api.example.com" {
hasDomainRoute = true
}
}
assert.True(t, hasDomainRoute, "source peer should receive domain route for domain network resource")
}
func TestNetworkMapComponents_NetworkEmpty(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
nm := networkMapFromComponents(t, account, "nonexistent-peer", validated)
assert.NotNil(t, nm)
assert.Empty(t, nm.Peers)
assert.Empty(t, nm.FirewallRules)
assert.NotNil(t, nm.Network)
}
func TestNetworkMapComponents_RouterExcludesOtherNetworkRoutes(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
ID: "resource-other", NetworkID: "net-other", AccountID: account.Id, Enabled: true,
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.99.1/32"), Address: "10.200.99.1/32",
})
account.Networks = append(account.Networks, &networkTypes.Network{
ID: "net-other", Name: "Other Net", AccountID: account.Id,
})
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
ID: "router-other", NetworkID: "net-other", Peer: "peer-dst-1", Enabled: true, AccountID: account.Id,
})
account.Groups["group-res-other"] = &types.Group{ID: "group-res-other", Name: "Other resource group",
Resources: []types.Resource{{ID: "resource-other"}},
}
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-other-resource", Name: "Other resource policy", Enabled: true, AccountID: account.Id,
Rules: []*types.PolicyRule{{
ID: "rule-other", Name: "Other resource access", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Sources: []string{"group-src"},
DestinationResource: types.Resource{ID: "resource-other"},
}},
})
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
for _, r := range nm.Routes {
assert.NotEqual(t, "10.200.99.1/32", r.Network.String(), "router-1 should NOT get routes for other network's resources")
}
}
func createComponentTestAccount() *types.Account {
peers := map[string]*nbpeer.Peer{
"peer-src-1": {
ID: "peer-src-1", IP: net.IP{100, 64, 0, 1}, Key: "key-src-1", DNSLabel: "src1",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
},
"peer-src-2": {
ID: "peer-src-2", IP: net.IP{100, 64, 0, 2}, Key: "key-src-2", DNSLabel: "src2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
},
"peer-dst-1": {
ID: "peer-dst-1", IP: net.IP{100, 64, 0, 3}, Key: "key-dst-1", DNSLabel: "dst1",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-2",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
},
"peer-router-1": {
ID: "peer-router-1", IP: net.IP{100, 64, 0, 10}, Key: "key-router-1", DNSLabel: "router1",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
},
}
groups := map[string]*types.Group{
"group-src": {ID: "group-src", Name: "Sources", Peers: []string{"peer-src-1", "peer-src-2"}},
"group-dst": {ID: "group-dst", Name: "Destinations", Peers: []string{"peer-dst-1"}},
"group-all": {ID: "group-all", Name: "All", Peers: []string{"peer-src-1", "peer-src-2", "peer-dst-1", "peer-router-1"}},
"group-res": {
ID: "group-res", Name: "Resource Group",
Resources: []types.Resource{{ID: "resource-1"}},
},
}
policies := []*types.Policy{
{
ID: "policy-base", Name: "Base connectivity", Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-base", Name: "Allow src <-> dst", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Bidirectional: true,
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
}},
},
{
ID: "policy-resource", Name: "Network resource access", Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-resource", Name: "Source -> Resource", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Sources: []string{"group-src"},
DestinationResource: types.Resource{ID: "resource-1"},
}},
},
}
routes := map[route.ID]*route.Route{
"route-main": {
ID: "route-main", Network: netip.MustParsePrefix("192.168.10.0/24"),
Peer: peers["peer-dst-1"].Key, PeerID: "peer-dst-1",
Enabled: true, Metric: 100,
Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-dst"},
},
}
users := map[string]*types.User{
"user-1": {Id: "user-1", Role: types.UserRoleAdmin, IsServiceUser: false, AutoGroups: []string{"group-all"}},
"user-2": {Id: "user-2", Role: types.UserRoleUser, IsServiceUser: false, AutoGroups: []string{"group-all"}},
}
account := &types.Account{
Id: "account-components-test", Peers: peers, Groups: groups, Policies: policies, Routes: routes,
Users: users,
Network: &types.Network{
Identifier: "net-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
},
DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{"group-dst"}},
NameServerGroups: map[string]*nbdns.NameServerGroup{
"ns-main": {
ID: "ns-main", Name: "Main NS", Enabled: true, Groups: []string{"group-src"},
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}},
},
},
PostureChecks: []*posture.Checks{},
NetworkResources: []*resourceTypes.NetworkResource{
{
ID: "resource-1", NetworkID: "net-1", AccountID: "account-components-test", Enabled: true,
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.0.1/32"), Address: "10.200.0.1/32",
},
},
Networks: []*networkTypes.Network{
{ID: "net-1", Name: "Resource Net", AccountID: "account-components-test"},
},
NetworkRouters: []*routerTypes.NetworkRouter{
{ID: "router-1", NetworkID: "net-1", Peer: "peer-router-1", Enabled: true, AccountID: "account-components-test"},
},
Settings: &types.Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: 24 * time.Hour},
}
for _, p := range account.Policies {
p.AccountID = account.Id
}
for _, r := range account.Routes {
r.AccountID = account.Id
}
return account
}

View File

@@ -1,967 +0,0 @@
package types_test
import (
"context"
"encoding/json"
"fmt"
"net"
"net/netip"
"os"
"path/filepath"
"slices"
"sort"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
const (
numPeers = 100
devGroupID = "group-dev"
opsGroupID = "group-ops"
allGroupID = "group-all"
sshUsersGroupID = "group-ssh-users"
routeID = route.ID("route-main")
routeHA1ID = route.ID("route-ha-1")
routeHA2ID = route.ID("route-ha-2")
policyIDDevOps = "policy-dev-ops"
policyIDAll = "policy-all"
policyIDPosture = "policy-posture"
policyIDDrop = "policy-drop"
policyIDSSH = "policy-ssh"
postureCheckID = "posture-check-ver"
networkResourceID = "res-database"
networkID = "net-database"
networkRouterID = "router-database"
nameserverGroupID = "ns-group-main"
testingPeerID = "peer-60" // A peer from the "dev" group, should receive the most detailed map.
expiredPeerID = "peer-98" // This peer will be online but with an expired session.
offlinePeerID = "peer-99" // This peer will be completely offline.
routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network.
testAccountID = "account-golden-test"
userAdminID = "user-admin"
userDevID = "user-dev"
userOpsID = "user-ops"
)
func TestGetPeerNetworkMap_Golden(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new.json")
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps from legacy and new builder do not match")
}
}
func BenchmarkGetPeerNetworkMap(b *testing.B) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
var peerIDs []string
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
validatedPeersMap[peerID] = struct{}{}
peerIDs = append(peerIDs, peerID)
}
b.ResetTimer()
b.Run("old builder", func(b *testing.B) {
for range b.N {
for _, peerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
b.ResetTimer()
b.Run("new builder", func(b *testing.B) {
for range b.N {
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
for _, peerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
}
}
})
}
func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newPeerID := "peer-new-101"
newPeerIP := net.IP{100, 64, 1, 1}
newPeer := &nbpeer.Peer{
ID: newPeerID,
IP: newPeerIP,
Key: fmt.Sprintf("key-%s", newPeerID),
DNSLabel: "peernew101",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
}
account.Peers[newPeerID] = newPeer
if devGroup, exists := account.Groups[devGroupID]; exists {
devGroup.Peers = append(devGroup.Peers, newPeerID)
}
if allGroup, exists := account.Groups[allGroupID]; exists {
allGroup.Peers = append(allGroup.Peers, newPeerID)
}
validatedPeersMap[newPeerID] = struct{}{}
if account.Network != nil {
account.Network.Serial++
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
err = builder.OnPeerAddedIncremental(account, newPeerID)
require.NoError(t, err, "error adding peer to cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json")
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new peer from legacy and new builder do not match")
}
}
func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
var peerIDs []string
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
validatedPeersMap[peerID] = struct{}{}
peerIDs = append(peerIDs, peerID)
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newPeerID := "peer-new-101"
newPeer := &nbpeer.Peer{
ID: newPeerID,
IP: net.IP{100, 64, 1, 1},
Key: fmt.Sprintf("key-%s", newPeerID),
DNSLabel: "peernew101",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
}
account.Peers[newPeerID] = newPeer
account.Groups[devGroupID].Peers = append(account.Groups[devGroupID].Peers, newPeerID)
account.Groups[allGroupID].Peers = append(account.Groups[allGroupID].Peers, newPeerID)
validatedPeersMap[newPeerID] = struct{}{}
b.ResetTimer()
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
b.ResetTimer()
b.Run("new builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = builder.OnPeerAddedIncremental(account, newPeerID)
for _, testingPeerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
}
}
})
}
func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newRouterID := "peer-new-router-102"
newRouterIP := net.IP{100, 64, 1, 2}
newRouter := &nbpeer.Peer{
ID: newRouterID,
IP: newRouterIP,
Key: fmt.Sprintf("key-%s", newRouterID),
DNSLabel: "newrouter102",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
}
account.Peers[newRouterID] = newRouter
if opsGroup, exists := account.Groups[opsGroupID]; exists {
opsGroup.Peers = append(opsGroup.Peers, newRouterID)
}
if allGroup, exists := account.Groups[allGroupID]; exists {
allGroup.Peers = append(allGroup.Peers, newRouterID)
}
newRoute := &route.Route{
ID: route.ID("route-new-router"),
Network: netip.MustParsePrefix("172.16.0.0/24"),
Peer: newRouter.Key,
PeerID: newRouterID,
Description: "Route from new router",
Enabled: true,
PeerGroups: []string{opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{devGroupID},
AccountID: account.Id,
}
account.Routes[newRoute.ID] = newRoute
validatedPeersMap[newRouterID] = struct{}{}
if account.Network != nil {
account.Network.Serial++
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
err = builder.OnPeerAddedIncremental(account, newRouterID)
require.NoError(t, err, "error adding router to cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new router from legacy and new builder do not match")
}
}
func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
var peerIDs []string
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
validatedPeersMap[peerID] = struct{}{}
peerIDs = append(peerIDs, peerID)
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newRouterID := "peer-new-router-102"
newRouterIP := net.IP{100, 64, 1, 2}
newRouter := &nbpeer.Peer{
ID: newRouterID,
IP: newRouterIP,
Key: fmt.Sprintf("key-%s", newRouterID),
DNSLabel: "newrouter102",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
}
account.Peers[newRouterID] = newRouter
if opsGroup, exists := account.Groups[opsGroupID]; exists {
opsGroup.Peers = append(opsGroup.Peers, newRouterID)
}
if allGroup, exists := account.Groups[allGroupID]; exists {
allGroup.Peers = append(allGroup.Peers, newRouterID)
}
newRoute := &route.Route{
ID: route.ID("route-new-router"),
Network: netip.MustParsePrefix("172.16.0.0/24"),
Peer: newRouter.Key,
PeerID: newRouterID,
Description: "Route from new router",
Enabled: true,
PeerGroups: []string{opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{devGroupID},
AccountID: account.Id,
}
account.Routes[newRoute.ID] = newRoute
validatedPeersMap[newRouterID] = struct{}{}
b.ResetTimer()
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
b.ResetTimer()
b.Run("new builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = builder.OnPeerAddedIncremental(account, newRouterID)
for _, testingPeerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
}
}
})
}
func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
deletedPeerID := "peer-25"
delete(account.Peers, deletedPeerID)
if devGroup, exists := account.Groups[devGroupID]; exists {
devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool {
return id == deletedPeerID
})
}
if allGroup, exists := account.Groups[allGroupID]; exists {
allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool {
return id == deletedPeerID
})
}
delete(validatedPeersMap, deletedPeerID)
if account.Network != nil {
account.Network.Serial++
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
err = builder.OnPeerDeleted(account, deletedPeerID)
require.NoError(t, err, "error deleting peer from cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json")
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted peer from legacy and new builder do not match")
}
}
func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
deletedRouterID := "peer-75"
var affectedRoute *route.Route
for _, r := range account.Routes {
if r.PeerID == deletedRouterID {
affectedRoute = r
break
}
}
require.NotNil(t, affectedRoute, "Router peer should have a route")
for _, group := range account.Groups {
group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool {
return id == deletedRouterID
})
}
for routeID, r := range account.Routes {
if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID {
delete(account.Routes, routeID)
}
}
delete(account.Peers, deletedRouterID)
delete(validatedPeersMap, deletedRouterID)
if account.Network != nil {
account.Network.Serial++
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
err = builder.OnPeerDeleted(account, deletedRouterID)
require.NoError(t, err, "error deleting routing peer from cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json")
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted router from legacy and new builder do not match")
}
}
func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
var peerIDs []string
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
validatedPeersMap[peerID] = struct{}{}
peerIDs = append(peerIDs, peerID)
}
deletedPeerID := "peer-25"
delete(account.Peers, deletedPeerID)
account.Groups[devGroupID].Peers = slices.DeleteFunc(account.Groups[devGroupID].Peers, func(id string) bool {
return id == deletedPeerID
})
account.Groups[allGroupID].Peers = slices.DeleteFunc(account.Groups[allGroupID].Peers, func(id string) bool {
return id == deletedPeerID
})
delete(validatedPeersMap, deletedPeerID)
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
b.ResetTimer()
b.Run("old builder after delete", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
b.ResetTimer()
b.Run("new builder after delete", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = builder.OnPeerDeleted(account, deletedPeerID)
for _, testingPeerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
}
}
})
}
func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) {
for _, peer := range networkMap.Peers {
if peer.Status != nil {
peer.Status.LastSeen = time.Time{}
}
peer.LastLogin = &time.Time{}
}
for _, peer := range networkMap.OfflinePeers {
if peer.Status != nil {
peer.Status.LastSeen = time.Time{}
}
peer.LastLogin = &time.Time{}
}
sort.Slice(networkMap.Peers, func(i, j int) bool { return networkMap.Peers[i].ID < networkMap.Peers[j].ID })
sort.Slice(networkMap.OfflinePeers, func(i, j int) bool { return networkMap.OfflinePeers[i].ID < networkMap.OfflinePeers[j].ID })
sort.Slice(networkMap.Routes, func(i, j int) bool { return networkMap.Routes[i].ID < networkMap.Routes[j].ID })
sort.Slice(networkMap.FirewallRules, func(i, j int) bool {
r1, r2 := networkMap.FirewallRules[i], networkMap.FirewallRules[j]
if r1.PeerIP != r2.PeerIP {
return r1.PeerIP < r2.PeerIP
}
if r1.Protocol != r2.Protocol {
return r1.Protocol < r2.Protocol
}
if r1.Direction != r2.Direction {
return r1.Direction < r2.Direction
}
if r1.Action != r2.Action {
return r1.Action < r2.Action
}
return r1.Port < r2.Port
})
sort.Slice(networkMap.RoutesFirewallRules, func(i, j int) bool {
r1, r2 := networkMap.RoutesFirewallRules[i], networkMap.RoutesFirewallRules[j]
if r1.RouteID != r2.RouteID {
return r1.RouteID < r2.RouteID
}
if r1.Action != r2.Action {
return r1.Action < r2.Action
}
if r1.Destination != r2.Destination {
return r1.Destination < r2.Destination
}
if len(r1.SourceRanges) > 0 && len(r2.SourceRanges) > 0 {
if r1.SourceRanges[0] != r2.SourceRanges[0] {
return r1.SourceRanges[0] < r2.SourceRanges[0]
}
}
return r1.Port < r2.Port
})
for _, ranges := range networkMap.RoutesFirewallRules {
sort.Slice(ranges.SourceRanges, func(i, j int) bool {
return ranges.SourceRanges[i] < ranges.SourceRanges[j]
})
}
}
type networkMapJSON struct {
Peers []*nbpeer.Peer `json:"Peers"`
Network *types.Network `json:"Network"`
Routes []*route.Route `json:"Routes"`
DNSConfig dns.Config `json:"DNSConfig"`
OfflinePeers []*nbpeer.Peer `json:"OfflinePeers"`
FirewallRules []*types.FirewallRule `json:"FirewallRules"`
RoutesFirewallRules []*types.RouteFirewallRule `json:"RoutesFirewallRules"`
ForwardingRules []*types.ForwardingRule `json:"ForwardingRules"`
AuthorizedUsers map[string][]string `json:"AuthorizedUsers,omitempty"`
EnableSSH bool `json:"EnableSSH"`
}
func toNetworkMapJSON(nm *types.NetworkMap) *networkMapJSON {
result := &networkMapJSON{
Peers: nm.Peers,
Network: nm.Network,
Routes: nm.Routes,
DNSConfig: nm.DNSConfig,
OfflinePeers: nm.OfflinePeers,
FirewallRules: nm.FirewallRules,
RoutesFirewallRules: nm.RoutesFirewallRules,
ForwardingRules: nm.ForwardingRules,
EnableSSH: nm.EnableSSH,
}
if len(nm.AuthorizedUsers) > 0 {
result.AuthorizedUsers = make(map[string][]string)
localUsers := make([]string, 0, len(nm.AuthorizedUsers))
for localUser := range nm.AuthorizedUsers {
localUsers = append(localUsers, localUser)
}
sort.Strings(localUsers)
for _, localUser := range localUsers {
userIDs := nm.AuthorizedUsers[localUser]
sortedUserIDs := make([]string, 0, len(userIDs))
for userID := range userIDs {
sortedUserIDs = append(sortedUserIDs, userID)
}
sort.Strings(sortedUserIDs)
result.AuthorizedUsers[localUser] = sortedUserIDs
}
}
return result
}
func createTestAccountWithEntities() *types.Account {
peers := make(map[string]*nbpeer.Peer)
devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{}
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
ip := net.IP{100, 64, 0, byte(i + 1)}
wtVersion := "0.25.0"
if i%2 == 0 {
wtVersion = "0.40.0"
}
p := &nbpeer.Peer{
ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1),
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"},
}
if peerID == expiredPeerID {
p.LoginExpirationEnabled = true
pastTimestamp := time.Now().Add(-2 * time.Hour)
p.LastLogin = &pastTimestamp
}
peers[peerID] = p
allGroupPeers = append(allGroupPeers, peerID)
if i < numPeers/2 {
devGroupPeers = append(devGroupPeers, peerID)
} else {
opsGroupPeers = append(opsGroupPeers, peerID)
}
}
groups := map[string]*types.Group{
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
sshUsersGroupID: {ID: sshUsersGroupID, Name: "SSH Users", Peers: []string{}},
}
policies := []*types.Policy{
{
ID: policyIDAll, Name: "Default-Allow", Enabled: true,
Rules: []*types.PolicyRule{{
ID: policyIDAll, Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolALL, Bidirectional: true,
Sources: []string{allGroupID}, Destinations: []string{allGroupID},
}},
},
{
ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true,
Rules: []*types.PolicyRule{{
ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolTCP, Bidirectional: false,
PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}},
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
}},
},
{
ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true,
Rules: []*types.PolicyRule{{
ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop,
Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true,
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
}},
},
{
ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true,
SourcePostureChecks: []string{postureCheckID},
Rules: []*types.PolicyRule{{
ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolALL, Bidirectional: true,
Sources: []string{opsGroupID}, DestinationResource: types.Resource{ID: networkResourceID},
}},
},
{
ID: policyIDSSH, Name: "SSH Access Policy", Enabled: true,
Rules: []*types.PolicyRule{{
ID: policyIDSSH, Name: "Allow SSH to Ops", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolNetbirdSSH, Bidirectional: false,
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
AuthorizedGroups: map[string][]string{sshUsersGroupID: {"root", "admin"}},
}},
},
}
routes := map[route.ID]*route.Route{
routeID: {
ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"),
Peer: peers["peer-75"].Key,
PeerID: "peer-75",
Description: "Route to internal resource", Enabled: true,
PeerGroups: []string{devGroupID, opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{devGroupID},
},
routeHA1ID: {
ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
Peer: peers["peer-80"].Key,
PeerID: "peer-80",
Description: "HA Route 1", Enabled: true, Metric: 1000,
PeerGroups: []string{allGroupID},
Groups: []string{allGroupID},
AccessControlGroups: []string{allGroupID},
},
routeHA2ID: {
ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
Peer: peers["peer-90"].Key,
PeerID: "peer-90",
Description: "HA Route 2", Enabled: true, Metric: 900,
PeerGroups: []string{devGroupID, opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{allGroupID},
},
}
users := map[string]*types.User{
userAdminID: {Id: userAdminID, Role: types.UserRoleAdmin, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{allGroupID}},
userDevID: {Id: userDevID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, devGroupID}},
userOpsID: {Id: userOpsID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, opsGroupID}},
}
account := &types.Account{
Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes,
Users: users,
Network: &types.Network{
Identifier: "net-golden-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
},
DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{opsGroupID}},
NameServerGroups: map[string]*dns.NameServerGroup{
nameserverGroupID: {
ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID},
NameServers: []dns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: dns.UDPNameServerType, Port: 53}},
},
},
PostureChecks: []*posture.Checks{
{ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
}},
},
NetworkResources: []*resourceTypes.NetworkResource{
{ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"},
},
Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}},
NetworkRouters: []*routerTypes.NetworkRouter{
{ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID},
},
Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
}
for _, p := range account.Policies {
p.AccountID = account.Id
}
for _, r := range account.Routes {
r.AccountID = account.Id
}
return account
}
func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newRouterID := "peer-new-router-102"
newRouterIP := net.IP{100, 64, 1, 2}
newRouter := &nbpeer.Peer{
ID: newRouterID,
IP: newRouterIP,
Key: fmt.Sprintf("key-%s", newRouterID),
DNSLabel: "newrouter102",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
}
account.Peers[newRouterID] = newRouter
if opsGroup, exists := account.Groups[opsGroupID]; exists {
opsGroup.Peers = append(opsGroup.Peers, newRouterID)
}
if allGroup, exists := account.Groups[allGroupID]; exists {
allGroup.Peers = append(allGroup.Peers, newRouterID)
}
newRoute := &route.Route{
ID: route.ID("route-new-router"),
Network: netip.MustParsePrefix("172.16.0.0/24"),
Peer: newRouter.Key,
PeerID: newRouterID,
Description: "Route from new router",
Enabled: true,
PeerGroups: []string{opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{devGroupID},
AccountID: account.Id,
}
account.Routes[newRoute.ID] = newRoute
validatedPeersMap[newRouterID] = struct{}{}
if account.Network != nil {
account.Network.Serial++
}
builder.EnqueuePeersForIncrementalAdd(account, newRouterID)
time.Sleep(100 * time.Millisecond)
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(networkMap)
jsonData, err := json.MarshalIndent(networkMap, "", " ")
require.NoError(t, err, "error marshaling network map to JSON")
goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")
t.Log("Update golden file with OnPeerAdded router...")
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(goldenFilePath, jsonData, 0644)
require.NoError(t, err)
expectedJSON, err := os.ReadFile(goldenFilePath)
require.NoError(t, err, "error reading golden file")
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file")
}

File diff suppressed because it is too large Load Diff

View File

@@ -1586,7 +1586,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })
@@ -1609,7 +1609,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
select { select {
case <-done: case <-done:
case <-time.After(time.Second): case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate") t.Error("timeout waiting for peerShouldReceiveUpdate")
} }
}) })

View File

@@ -0,0 +1,409 @@
package proxy
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
grpcstatus "google.golang.org/grpc/status"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/proto"
)
type byopTestSetup struct {
store store.Store
proxyService *nbgrpc.ProxyServiceServer
grpcServer *grpc.Server
grpcAddr string
cleanup func()
accountA string
accountB string
accountAToken types.PlainProxyToken
accountBToken types.PlainProxyToken
accountACluster string
accountBCluster string
}
func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup {
t.Helper()
ctx := context.Background()
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
accountAID := "byop-account-a"
accountBID := "byop-account-b"
for _, acc := range []*types.Account{
{Id: accountAID, Domain: "a.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()},
{Id: accountBID, Domain: "b.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()},
} {
require.NoError(t, testStore.SaveAccount(ctx, acc))
}
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := base64.StdEncoding.EncodeToString(pub)
privKey := base64.StdEncoding.EncodeToString(priv)
clusterA := "byop-a.proxy.test"
clusterB := "byop-b.proxy.test"
services := []*service.Service{
{
ID: "svc-a1", AccountID: accountAID, Name: "App A1",
Domain: "app1." + clusterA, ProxyCluster: clusterA, Enabled: true,
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.1", Port: 8080, Protocol: "http", TargetId: "peer-a1", TargetType: "peer", Enabled: true}},
},
{
ID: "svc-a2", AccountID: accountAID, Name: "App A2",
Domain: "app2." + clusterA, ProxyCluster: clusterA, Enabled: true,
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.2", Port: 8080, Protocol: "http", TargetId: "peer-a2", TargetType: "peer", Enabled: true}},
},
{
ID: "svc-b1", AccountID: accountBID, Name: "App B1",
Domain: "app1." + clusterB, ProxyCluster: clusterB, Enabled: true,
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.3", Port: 8080, Protocol: "http", TargetId: "peer-b1", TargetType: "peer", Enabled: true}},
},
}
for _, svc := range services {
require.NoError(t, testStore.CreateService(ctx, svc))
}
tokenA, err := types.CreateNewProxyAccessToken("byop-token-a", 0, &accountAID, "admin-a")
require.NoError(t, err)
require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenA.ProxyAccessToken))
tokenB, err := types.CreateNewProxyAccessToken("byop-token-b", 0, &accountBID, "admin-b")
require.NoError(t, err)
require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenB.ProxyAccessToken))
cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore)
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore)
meter := noop.NewMeterProvider().Meter("test")
realProxyManager, err := proxymanager.NewManager(testStore, meter)
require.NoError(t, err)
oidcConfig := nbgrpc.ProxyOIDCConfig{
Issuer: "https://fake-issuer.example.com",
ClientID: "test-client",
HMACKey: []byte("test-hmac-key"),
}
usersManager := users.NewManager(testStore)
proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{},
tokenStore,
pkceStore,
oidcConfig,
nil,
usersManager,
realProxyManager,
nil,
)
svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore}
proxyService.SetServiceManager(svcMgr)
proxyController := &testProxyController{}
proxyService.SetProxyController(proxyController)
_, streamInterceptor, authClose := nbgrpc.NewProxyAuthInterceptors(testStore)
lis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
grpcServer := grpc.NewServer(grpc.StreamInterceptor(streamInterceptor))
proto.RegisterProxyServiceServer(grpcServer, proxyService)
go func() {
if err := grpcServer.Serve(lis); err != nil {
t.Logf("gRPC server error: %v", err)
}
}()
return &byopTestSetup{
store: testStore,
proxyService: proxyService,
grpcServer: grpcServer,
grpcAddr: lis.Addr().String(),
cleanup: func() {
grpcServer.GracefulStop()
authClose()
storeCleanup()
},
accountA: accountAID,
accountB: accountBID,
accountAToken: tokenA.PlainToken,
accountBToken: tokenB.PlainToken,
accountACluster: clusterA,
accountBCluster: clusterB,
}
}
func byopContext(ctx context.Context, token types.PlainProxyToken) context.Context {
md := metadata.Pairs("authorization", "Bearer "+string(token))
return metadata.NewOutgoingContext(ctx, md)
}
func receiveBYOPMappings(t *testing.T, stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
t.Helper()
var mappings []*proto.ProxyMapping
for {
msg, err := stream.Recv()
require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...)
if msg.GetInitialSyncComplete() {
break
}
}
return mappings
}
func TestIntegration_BYOPProxy_ReceivesOnlyAccountServices(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
defer cancel()
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-a",
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
mappings := receiveBYOPMappings(t, stream)
assert.Len(t, mappings, 2, "BYOP proxy should receive only account A's 2 services")
for _, m := range mappings {
assert.Equal(t, setup.accountA, m.GetAccountId(), "all mappings should belong to account A")
t.Logf("received mapping: id=%s domain=%s account=%s", m.GetId(), m.GetDomain(), m.GetAccountId())
}
ids := map[string]bool{}
for _, m := range mappings {
ids[m.GetId()] = true
}
assert.True(t, ids["svc-a1"], "should contain svc-a1")
assert.True(t, ids["svc-a2"], "should contain svc-a2")
assert.False(t, ids["svc-b1"], "should NOT contain account B's svc-b1")
}
func TestIntegration_BYOPProxy_AccountBReceivesOnlyItsServices(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second)
defer cancel()
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-b",
Version: "test-v1",
Address: setup.accountBCluster,
})
require.NoError(t, err)
mappings := receiveBYOPMappings(t, stream)
assert.Len(t, mappings, 1, "BYOP proxy B should receive only 1 service")
assert.Equal(t, "svc-b1", mappings[0].GetId())
assert.Equal(t, setup.accountB, mappings[0].GetAccountId())
}
func TestIntegration_BYOPProxy_MultiplePerAccount(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
defer cancel1()
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-a-first",
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
mappings1 := receiveBYOPMappings(t, stream1)
assert.Len(t, mappings1, 2, "first BYOP proxy should receive account A's 2 services")
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
defer cancel2()
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-a-second",
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
mappings2 := receiveBYOPMappings(t, stream2)
assert.Len(t, mappings2, 2, "second BYOP proxy from same account should also receive the 2 services")
for _, m := range mappings2 {
assert.Equal(t, setup.accountA, m.GetAccountId())
}
}
func TestIntegration_BYOPProxy_ClusterAddressConflict(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
defer cancel1()
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-a-cluster",
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
_ = receiveBYOPMappings(t, stream1)
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second)
defer cancel2()
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
ProxyId: "byop-proxy-b-conflict",
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
_, err = stream2.Recv()
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.AlreadyExists, st.Code(), "cluster address conflict should return AlreadyExists")
t.Logf("expected rejection: %s", st.Message())
}
func TestIntegration_BYOPProxy_SameProxyReconnects(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
proxyID := "byop-proxy-reconnect"
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
firstMappings := receiveBYOPMappings(t, stream1)
cancel1()
time.Sleep(200 * time.Millisecond)
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
defer cancel2()
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: setup.accountACluster,
})
require.NoError(t, err)
secondMappings := receiveBYOPMappings(t, stream2)
assert.Equal(t, len(firstMappings), len(secondMappings), "reconnect should receive same mappings")
firstIDs := map[string]bool{}
for _, m := range firstMappings {
firstIDs[m.GetId()] = true
}
for _, m := range secondMappings {
assert.True(t, firstIDs[m.GetId()], "mapping %s should be present on reconnect", m.GetId())
}
}
func TestIntegration_BYOPProxy_UnauthenticatedRejected(t *testing.T) {
setup := setupBYOPIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: "no-auth-proxy",
Version: "test-v1",
Address: "some.cluster.io",
})
require.NoError(t, err)
_, err = stream.Recv()
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.Unauthenticated, st.Code())
}

View File

@@ -6,6 +6,7 @@ import (
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -140,6 +141,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
nil, nil,
usersManager, usersManager,
proxyManager, proxyManager,
nil,
) )
// Use store-backed service manager // Use store-backed service manager
@@ -201,7 +203,7 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
// testProxyManager is a mock implementation of proxy.Manager for testing. // testProxyManager is a mock implementation of proxy.Manager for testing.
type testProxyManager struct{} type testProxyManager struct{}
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error { func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *nbproxy.Capabilities) error {
return nil return nil
} }
@@ -217,6 +219,10 @@ func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]strin
return nil, nil return nil, nil
} }
func (m *testProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
return nil, nil
}
func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Cluster, error) { func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Cluster, error) {
return nil, nil return nil, nil
} }
@@ -237,6 +243,22 @@ func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) erro
return nil return nil
} }
func (m *testProxyManager) GetAccountProxy(_ context.Context, accountID string) (*nbproxy.Proxy, error) {
return nil, fmt.Errorf("proxy not found for account %s", accountID)
}
func (m *testProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
return 0, nil
}
func (m *testProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
return true, nil
}
func (m *testProxyManager) DeleteProxy(_ context.Context, _ string) error {
return nil
}
// testProxyController is a mock implementation of rpservice.ProxyController for testing. // testProxyController is a mock implementation of rpservice.ProxyController for testing.
type testProxyController struct{} type testProxyController struct{}
@@ -336,6 +358,10 @@ func (m *storeBackedServiceManager) StopServiceFromPeer(_ context.Context, _, _,
func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {} func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {}
func (m *storeBackedServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) { func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
return nil, nil return nil, nil
} }

View File

@@ -3323,10 +3323,64 @@ components:
example: false example: false
required: required:
- enabled - enabled
ProxyTokenRequest:
type: object
properties:
name:
type: string
description: Human-readable token name
example: "my-proxy-token"
expires_in:
type: integer
minimum: 0
description: Token expiration in seconds (0 = never expires)
example: 0
required:
- name
ProxyToken:
type: object
properties:
id:
type: string
name:
type: string
expires_at:
type: string
format: date-time
created_at:
type: string
format: date-time
last_used:
type: string
format: date-time
revoked:
type: boolean
required:
- id
- name
- created_at
- revoked
ProxyTokenCreated:
type: object
description: Returned on creation — plain_token is shown only once
allOf:
- $ref: '#/components/schemas/ProxyToken'
- type: object
properties:
plain_token:
type: string
description: The plain text token (shown only once)
example: "nbx_abc123..."
required:
- plain_token
ProxyCluster: ProxyCluster:
type: object type: object
description: A proxy cluster represents a group of proxy nodes serving the same address description: A proxy cluster represents a group of proxy nodes serving the same address
properties: properties:
id:
type: string
description: Unique identifier of a proxy in this cluster
example: "chlfq4q5r8kc73b0qjpg"
address: address:
type: string type: string
description: Cluster address used for CNAME targets description: Cluster address used for CNAME targets
@@ -3335,9 +3389,15 @@ components:
type: integer type: integer
description: Number of proxy nodes connected in this cluster description: Number of proxy nodes connected in this cluster
example: 3 example: 3
self_hosted:
type: boolean
description: Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner
example: false
required: required:
- id
- address - address
- connected_proxies - connected_proxies
- self_hosted
ReverseProxyDomainType: ReverseProxyDomainType:
type: string type: string
description: Type of Reverse Proxy Domain description: Type of Reverse Proxy Domain
@@ -11317,6 +11377,111 @@ paths:
"$ref": "#/components/responses/forbidden" "$ref": "#/components/responses/forbidden"
'500': '500':
"$ref": "#/components/responses/internal_error" "$ref": "#/components/responses/internal_error"
/api/reverse-proxies/clusters/{clusterId}:
delete:
summary: Delete a self-hosted proxy cluster
description: Removes a self-hosted (BYOP) proxy cluster and disconnects it. Only self-hosted clusters can be deleted.
tags: [ Services ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: clusterId
required: true
schema:
type: string
description: The unique identifier of the proxy cluster
responses:
'200':
description: Proxy cluster deleted successfully
content: { }
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
"$ref": "#/components/responses/not_found"
'500':
"$ref": "#/components/responses/internal_error"
/api/reverse-proxies/proxy-tokens:
get:
summary: List Proxy Tokens
description: Returns all proxy access tokens for the account
tags: [ Self-Hosted Proxies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of proxy tokens
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/ProxyToken'
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Create a Proxy Token
description: Generate an account-scoped proxy access token for self-hosted proxy registration
tags: [ Self-Hosted Proxies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/ProxyTokenRequest'
responses:
'200':
description: Proxy token created (plain token shown once)
content:
application/json:
schema:
$ref: '#/components/schemas/ProxyTokenCreated'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/reverse-proxies/proxy-tokens/{tokenId}:
delete:
summary: Revoke a Proxy Token
description: Revoke an account-scoped proxy access token
tags: [ Self-Hosted Proxies ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: tokenId
required: true
schema:
type: string
description: The unique identifier of the proxy token
responses:
'200':
description: Token revoked
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'404':
"$ref": "#/components/responses/not_found"
'500':
"$ref": "#/components/responses/internal_error"
/api/reverse-proxies/services: /api/reverse-proxies/services:
get: get:
summary: List all Services summary: List all Services

View File

@@ -3761,11 +3761,49 @@ type ProxyAccessLogsResponse struct {
// ProxyCluster A proxy cluster represents a group of proxy nodes serving the same address // ProxyCluster A proxy cluster represents a group of proxy nodes serving the same address
type ProxyCluster struct { type ProxyCluster struct {
// Id Unique identifier of a proxy in this cluster
Id string `json:"id"`
// Address Cluster address used for CNAME targets // Address Cluster address used for CNAME targets
Address string `json:"address"` Address string `json:"address"`
// ConnectedProxies Number of proxy nodes connected in this cluster // ConnectedProxies Number of proxy nodes connected in this cluster
ConnectedProxies int `json:"connected_proxies"` ConnectedProxies int `json:"connected_proxies"`
// SelfHosted Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner
SelfHosted bool `json:"self_hosted"`
}
// ProxyToken defines model for ProxyToken.
type ProxyToken struct {
CreatedAt time.Time `json:"created_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
Id string `json:"id"`
LastUsed *time.Time `json:"last_used,omitempty"`
Name string `json:"name"`
Revoked bool `json:"revoked"`
}
// ProxyTokenCreated defines model for ProxyTokenCreated.
type ProxyTokenCreated struct {
CreatedAt time.Time `json:"created_at"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
Id string `json:"id"`
LastUsed *time.Time `json:"last_used,omitempty"`
Name string `json:"name"`
// PlainToken The plain text token (shown only once)
PlainToken string `json:"plain_token"`
Revoked bool `json:"revoked"`
}
// ProxyTokenRequest defines model for ProxyTokenRequest.
type ProxyTokenRequest struct {
// ExpiresIn Token expiration in seconds (0 = never expires)
ExpiresIn *int `json:"expires_in,omitempty"`
// Name Human-readable token name
Name string `json:"name"`
} }
// Resource defines model for Resource. // Resource defines model for Resource.
@@ -5127,6 +5165,9 @@ type PutApiPostureChecksPostureCheckIdJSONRequestBody = PostureCheckUpdate
// PostApiReverseProxiesDomainsJSONRequestBody defines body for PostApiReverseProxiesDomains for application/json ContentType. // PostApiReverseProxiesDomainsJSONRequestBody defines body for PostApiReverseProxiesDomains for application/json ContentType.
type PostApiReverseProxiesDomainsJSONRequestBody = ReverseProxyDomainRequest type PostApiReverseProxiesDomainsJSONRequestBody = ReverseProxyDomainRequest
// PostApiReverseProxiesProxyTokensJSONRequestBody defines body for PostApiReverseProxiesProxyTokens for application/json ContentType.
type PostApiReverseProxiesProxyTokensJSONRequestBody = ProxyTokenRequest
// PostApiReverseProxiesServicesJSONRequestBody defines body for PostApiReverseProxiesServices for application/json ContentType. // PostApiReverseProxiesServicesJSONRequestBody defines body for PostApiReverseProxiesServices for application/json ContentType.
type PostApiReverseProxiesServicesJSONRequestBody = ServiceRequest type PostApiReverseProxiesServicesJSONRequestBody = ServiceRequest