mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-27 20:56:44 +00:00
Compare commits
10 Commits
dependabot
...
feat/pat-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
59569c5147 | ||
|
|
0c71aca86d | ||
|
|
be9f1b46e6 | ||
|
|
154b81645a | ||
|
|
34167c8a16 | ||
|
|
d6f08e4840 | ||
|
|
f732b01a05 | ||
|
|
c07c726ea7 | ||
|
|
fa0d58d093 | ||
|
|
b6038e8acd |
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.1.2"
|
||||
SIGN_PIPE_VER: "v0.1.4"
|
||||
GORELEASER_VER: "v2.14.3"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -16,11 +15,9 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/mod/semver"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -58,13 +55,6 @@ type Controller struct {
|
||||
proxyController port_forwarding.Controller
|
||||
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
holder *types.Holder
|
||||
|
||||
expNewNetworkMap bool
|
||||
expNewNetworkMapAIDs map[string]struct{}
|
||||
|
||||
compactedNetworkMap bool
|
||||
}
|
||||
|
||||
type bufferUpdate struct {
|
||||
@@ -81,29 +71,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
||||
}
|
||||
|
||||
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
|
||||
newNetworkMapBuilder = false
|
||||
}
|
||||
|
||||
compactedNetworkMap := true
|
||||
compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted)
|
||||
parsedCompactedNmap, err := strconv.ParseBool(compactedEnv)
|
||||
if err != nil && len(compactedEnv) > 0 {
|
||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err)
|
||||
}
|
||||
if err == nil && !parsedCompactedNmap {
|
||||
log.WithContext(ctx).Info("disabling compacted mode")
|
||||
compactedNetworkMap = false
|
||||
}
|
||||
|
||||
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
|
||||
expIDs := make(map[string]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
expIDs[id] = struct{}{}
|
||||
}
|
||||
|
||||
return &Controller{
|
||||
repo: newRepository(store),
|
||||
metrics: nMetrics,
|
||||
@@ -117,12 +84,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
|
||||
proxyController: proxyController,
|
||||
EphemeralPeersManager: ephemeralPeersManager,
|
||||
|
||||
holder: types.NewHolder(),
|
||||
expNewNetworkMap: newNetworkMapBuilder,
|
||||
expNewNetworkMapAIDs: expIDs,
|
||||
|
||||
compactedNetworkMap: compactedNetworkMap,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,17 +114,9 @@ func (c *Controller) CountStreams() int {
|
||||
|
||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account = c.getAccountFromHolderOrInit(ctx, accountID)
|
||||
} else {
|
||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account: %v", err)
|
||||
}
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account: %v", err)
|
||||
}
|
||||
|
||||
globalStart := time.Now()
|
||||
@@ -197,10 +150,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
||||
}
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
@@ -243,16 +192,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
switch {
|
||||
case c.experimentalNetworkMap(accountID):
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
case c.compactedNetworkMap:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
default:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
@@ -318,10 +258,6 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
|
||||
// UpdatePeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return fmt.Errorf("recalculate network map cache: %v", err)
|
||||
}
|
||||
|
||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -371,16 +307,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
return err
|
||||
}
|
||||
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
switch {
|
||||
case c.experimentalNetworkMap(accountId):
|
||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
case c.compactedNetworkMap:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
default:
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -451,17 +378,9 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return peer, emptyMap, nil, 0, nil
|
||||
}
|
||||
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account = c.getAccountFromHolderOrInit(ctx, accountID)
|
||||
} else {
|
||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
@@ -493,20 +412,10 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||
} else {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
if c.compactedNetworkMap {
|
||||
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
}
|
||||
}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -518,108 +427,6 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return peer, networkMap, postureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
c.enrichAccountFromHolder(account)
|
||||
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
|
||||
}
|
||||
|
||||
func (c *Controller) getPeerNetworkMapExp(
|
||||
ctx context.Context,
|
||||
accountId string,
|
||||
peerId string,
|
||||
validatedPeers map[string]struct{},
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *types.NetworkMap {
|
||||
account := c.getAccountFromHolderOrInit(ctx, accountId)
|
||||
if account == nil {
|
||||
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
|
||||
return &types.NetworkMap{
|
||||
Network: &types.Network{},
|
||||
}
|
||||
}
|
||||
|
||||
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
|
||||
c.enrichAccountFromHolder(account)
|
||||
account.OnPeersAddedUpdNetworkMapCache(peerIds...)
|
||||
}
|
||||
|
||||
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
||||
c.enrichAccountFromHolder(account)
|
||||
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
|
||||
}
|
||||
|
||||
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
|
||||
account := c.getAccountFromHolder(accountId)
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
account.UpdatePeerInNetworkMapCache(peer)
|
||||
}
|
||||
|
||||
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
account.RecalculateNetworkMapCache(validatedPeers)
|
||||
c.updateAccountInHolder(account)
|
||||
}
|
||||
|
||||
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
|
||||
if c.experimentalNetworkMap(accountId) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
|
||||
return err
|
||||
}
|
||||
c.recalculateNetworkMapCache(account, validatedPeers)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) experimentalNetworkMap(accountId string) bool {
|
||||
_, ok := c.expNewNetworkMapAIDs[accountId]
|
||||
return c.expNewNetworkMap || ok
|
||||
}
|
||||
|
||||
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
|
||||
a := c.holder.GetAccount(account.Id)
|
||||
if a == nil {
|
||||
c.holder.AddAccount(account)
|
||||
return
|
||||
}
|
||||
account.NetworkMapCache = a.NetworkMapCache
|
||||
if account.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
c.holder.AddAccount(account)
|
||||
}
|
||||
|
||||
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
|
||||
return c.holder.GetAccount(accountID)
|
||||
}
|
||||
|
||||
func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account {
|
||||
a := c.holder.GetAccount(accountID)
|
||||
if a != nil {
|
||||
return a
|
||||
}
|
||||
account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return account
|
||||
}
|
||||
|
||||
func (c *Controller) updateAccountInHolder(account *types.Account) {
|
||||
c.holder.AddAccount(account)
|
||||
}
|
||||
|
||||
// GetDNSDomain returns the configured dnsDomain
|
||||
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
|
||||
if settings == nil {
|
||||
@@ -756,16 +563,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get peers by ids: %w", err)
|
||||
}
|
||||
|
||||
for _, peer := range peers {
|
||||
c.UpdatePeerInNetworkMapCache(accountID, peer)
|
||||
}
|
||||
|
||||
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
err := c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
||||
}
|
||||
@@ -775,14 +573,6 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
|
||||
|
||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs)
|
||||
c.onPeersAddedUpdNetworkMapCache(account, peerIDs...)
|
||||
}
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -817,19 +607,6 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
})
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
|
||||
continue
|
||||
}
|
||||
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
@@ -872,21 +649,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if c.experimentalNetworkMap(peer.AccountID) {
|
||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
||||
} else {
|
||||
account.InjectProxyPolicies(ctx)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
if c.compactedNetworkMap {
|
||||
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
}
|
||||
}
|
||||
account.InjectProxyPolicies(ctx)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
|
||||
@@ -12,9 +12,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
|
||||
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
|
||||
|
||||
DnsForwarderPort = nbdns.ForwarderServerPort
|
||||
OldForwarderPort = nbdns.ForwarderClientPort
|
||||
DnsForwarderPortMinVersion = "v0.59.0"
|
||||
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
@@ -109,7 +110,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -117,6 +118,15 @@ func (s *BaseServer) APIHandler() http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
||||
return Create(s, func() *middleware.APIRateLimiter {
|
||||
cfg, enabled := middleware.RateLimiterConfigFromEnv()
|
||||
limiter := middleware.NewAPIRateLimiter(cfg)
|
||||
limiter.SetEnabled(enabled)
|
||||
return limiter
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
return Create(s, func() *grpc.Server {
|
||||
trustedPeers := s.Config.ReverseProxy.TrustedPeers
|
||||
|
||||
@@ -408,7 +408,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
networkMap := account.GetPeerNetworkMapFromComponents(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
}
|
||||
@@ -1171,11 +1171,6 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
@@ -1231,11 +1226,6 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
@@ -1274,11 +1264,6 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
@@ -1332,11 +1317,6 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
@@ -1397,11 +1377,6 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
@@ -1633,75 +1608,6 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
|
||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||
}
|
||||
|
||||
func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, prefix2, err := route.ParseNetwork("192.168.0.0/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
account := &types.Account{
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
|
||||
},
|
||||
Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
|
||||
Routes: map[route.ID]*route.Route{
|
||||
"route-1": {
|
||||
ID: "route-1",
|
||||
Network: prefix,
|
||||
NetID: "network-1",
|
||||
Description: "network-1",
|
||||
Peer: "peer-1",
|
||||
NetworkType: 0,
|
||||
Masquerade: false,
|
||||
Metric: 999,
|
||||
Enabled: true,
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
"route-2": {
|
||||
ID: "route-2",
|
||||
Network: prefix2,
|
||||
NetID: "network-2",
|
||||
Description: "network-2",
|
||||
Peer: "peer-2",
|
||||
NetworkType: 0,
|
||||
Masquerade: false,
|
||||
Metric: 999,
|
||||
Enabled: true,
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
"route-3": {
|
||||
ID: "route-3",
|
||||
Network: prefix,
|
||||
NetID: "network-1",
|
||||
Description: "network-1",
|
||||
Peer: "peer-2",
|
||||
NetworkType: 0,
|
||||
Masquerade: false,
|
||||
Metric: 999,
|
||||
Enabled: true,
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
|
||||
|
||||
assert.Len(t, routes, 2)
|
||||
routeIDs := make(map[route.ID]struct{}, 2)
|
||||
for _, r := range routes {
|
||||
routeIDs[r.ID] = struct{}{}
|
||||
}
|
||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||
assert.Contains(t, routeIDs, route.ID("route-3"))
|
||||
|
||||
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
|
||||
|
||||
assert.Len(t, emptyRoutes, 0)
|
||||
}
|
||||
|
||||
func TestAccount_Copy(t *testing.T) {
|
||||
account := &types.Account{
|
||||
Id: "account1",
|
||||
@@ -1824,9 +1730,7 @@ func TestAccount_Copy(t *testing.T) {
|
||||
AccountID: "account1",
|
||||
},
|
||||
},
|
||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||
}
|
||||
account.InitOnce()
|
||||
err := hasNilField(account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -2311,6 +2215,29 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetExpiredPeers_SkipsAlreadyExpired(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
// Verify the already-expired peer is excluded at the store level
|
||||
peers, err := testStore.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, peer := range peers {
|
||||
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should be excluded by the store query")
|
||||
assert.False(t, peer.Status.LoginExpired, "returned peers should not already be marked as login expired")
|
||||
}
|
||||
|
||||
// Only the non-expired peer with expiration enabled should be returned
|
||||
require.Len(t, peers, 1)
|
||||
assert.Equal(t, "notexpired01", peers[0].ID)
|
||||
}
|
||||
|
||||
func TestAccount_GetInactivePeers(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
@@ -3230,6 +3157,13 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
|
||||
return manager, updateManager, account, peer1, peer2, peer3
|
||||
}
|
||||
|
||||
// peerUpdateTimeout bounds how long peerShouldReceiveUpdate and its outer
|
||||
// wrappers wait for an expected update message. Sized for slow CI runners
|
||||
// (MySQL, FreeBSD, loaded sqlite) where the channel publish can take
|
||||
// seconds. Only runs down on failure; passing tests return immediately
|
||||
// when the channel delivers.
|
||||
const peerUpdateTimeout = 5 * time.Second
|
||||
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
|
||||
t.Helper()
|
||||
select {
|
||||
@@ -3248,7 +3182,7 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.Upd
|
||||
if msg == nil {
|
||||
t.Errorf("Received nil update message, expected valid message")
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("Timed out waiting for update message")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -458,7 +458,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -478,7 +478,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -518,7 +518,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -620,7 +620,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -638,7 +638,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -656,7 +656,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -689,7 +689,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -730,7 +730,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -757,7 +757,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -804,7 +804,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -5,9 +5,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/cors"
|
||||
@@ -65,15 +62,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
apiPrefix = "/api"
|
||||
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
|
||||
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
|
||||
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
|
||||
)
|
||||
const apiPrefix = "/api"
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -94,34 +86,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
|
||||
var rateLimitingConfig *middleware.RateLimiterConfig
|
||||
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
||||
rpm := 6
|
||||
if v := os.Getenv(rateLimitingRPMKey); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
|
||||
} else {
|
||||
rpm = value
|
||||
}
|
||||
}
|
||||
|
||||
burst := 500
|
||||
if v := os.Getenv(rateLimitingBurstKey); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
|
||||
} else {
|
||||
burst = value
|
||||
}
|
||||
}
|
||||
|
||||
rateLimitingConfig = &middleware.RateLimiterConfig{
|
||||
RequestsPerMinute: float64(rpm),
|
||||
Burst: burst,
|
||||
CleanupInterval: 6 * time.Hour,
|
||||
LimiterTTL: 24 * time.Hour,
|
||||
}
|
||||
if rateLimiter == nil {
|
||||
log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
|
||||
rateLimiter = middleware.NewAPIRateLimiter(nil)
|
||||
rateLimiter.SetEnabled(false)
|
||||
}
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
@@ -129,7 +97,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
accountManager.GetAccountIDFromUserAuth,
|
||||
accountManager.SyncUserJWTGroups,
|
||||
accountManager.GetUserFromUserAuth,
|
||||
rateLimitingConfig,
|
||||
rateLimiter,
|
||||
appMetrics.GetMeter(),
|
||||
)
|
||||
|
||||
@@ -171,7 +139,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
zonesManager.RegisterEndpoints(router, zManager)
|
||||
recordsManager.RegisterEndpoints(router, rManager)
|
||||
idp.AddEndpoints(accountManager, router)
|
||||
instance.AddEndpoints(instanceManager, router)
|
||||
instance.AddEndpoints(instanceManager, accountManager, router)
|
||||
instance.AddVersionEndpoint(instanceManager, router)
|
||||
if serviceManager != nil && reverseProxyDomainManager != nil {
|
||||
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
@@ -15,13 +16,15 @@ import (
|
||||
// handler handles the instance setup HTTP endpoints
|
||||
type handler struct {
|
||||
instanceManager nbinstance.Manager
|
||||
setupManager *nbinstance.SetupService
|
||||
}
|
||||
|
||||
// AddEndpoints registers the instance setup endpoints.
|
||||
// These endpoints bypass authentication for initial setup.
|
||||
func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
|
||||
func AddEndpoints(instanceManager nbinstance.Manager, accountManager account.Manager, router *mux.Router) {
|
||||
h := &handler{
|
||||
instanceManager: instanceManager,
|
||||
setupManager: nbinstance.NewSetupService(instanceManager, accountManager),
|
||||
}
|
||||
|
||||
router.HandleFunc("/instance", h.getInstanceStatus).Methods("GET", "OPTIONS")
|
||||
@@ -55,24 +58,35 @@ func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
|
||||
// setup creates the initial admin user for the instance.
|
||||
// This endpoint is unauthenticated but only works when setup is required.
|
||||
func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
var req api.SetupRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
userData, err := h.instanceManager.CreateOwnerUser(r.Context(), req.Email, req.Password, req.Name)
|
||||
result, err := h.setupManager.SetupOwner(ctx, req.Email, req.Password, req.Name, nbinstance.SetupOptions{
|
||||
CreatePAT: req.CreatePat != nil && *req.CreatePat,
|
||||
PATExpireInDays: req.PatExpireIn,
|
||||
})
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(r.Context()).Infof("instance setup completed: created user %s", req.Email)
|
||||
log.WithContext(ctx).Infof("instance setup completed: created user %s", req.Email)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, api.SetupResponse{
|
||||
UserId: userData.ID,
|
||||
Email: userData.Email,
|
||||
})
|
||||
resp := api.SetupResponse{
|
||||
UserId: result.User.ID,
|
||||
Email: result.User.Email,
|
||||
}
|
||||
|
||||
if result.PATPlainToken != "" {
|
||||
resp.PersonalAccessToken = &result.PATPlainToken
|
||||
}
|
||||
|
||||
util.WriteJSONObject(ctx, w, resp)
|
||||
}
|
||||
|
||||
// getVersionInfo returns version information for NetBird components.
|
||||
|
||||
@@ -10,12 +10,18 @@ import (
|
||||
"net/mail"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbstore "github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
@@ -25,6 +31,7 @@ type mockInstanceManager struct {
|
||||
isSetupRequired bool
|
||||
isSetupRequiredFn func(ctx context.Context) (bool, error)
|
||||
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
rollbackSetupFn func(ctx context.Context, userID string) error
|
||||
getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error)
|
||||
}
|
||||
|
||||
@@ -67,6 +74,13 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockInstanceManager) RollbackSetup(ctx context.Context, userID string) error {
|
||||
if m.rollbackSetupFn != nil {
|
||||
return m.rollbackSetupFn(ctx, userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) {
|
||||
if m.getVersionInfoFn != nil {
|
||||
return m.getVersionInfoFn(ctx)
|
||||
@@ -82,8 +96,12 @@ func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.V
|
||||
var _ nbinstance.Manager = (*mockInstanceManager)(nil)
|
||||
|
||||
func setupTestRouter(manager nbinstance.Manager) *mux.Router {
|
||||
return setupTestRouterWithPAT(manager, nil)
|
||||
}
|
||||
|
||||
func setupTestRouterWithPAT(manager nbinstance.Manager, accountManager account.Manager) *mux.Router {
|
||||
router := mux.NewRouter()
|
||||
AddEndpoints(manager, router)
|
||||
AddEndpoints(manager, accountManager, router)
|
||||
return router
|
||||
}
|
||||
|
||||
@@ -293,6 +311,222 @@ func TestSetup_ManagerError(t *testing.T) {
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
func TestSetup_PAT_FeatureDisabled_IgnoresCreatePAT(t *testing.T) {
|
||||
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "false")
|
||||
|
||||
manager := &mockInstanceManager{isSetupRequired: true}
|
||||
// NB_SETUP_PAT_ENABLED=false: request fields must be silently ignored
|
||||
router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{})
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response api.SetupResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
|
||||
assert.Nil(t, response.PersonalAccessToken)
|
||||
}
|
||||
|
||||
func TestSetup_PAT_FlagOmitted_NoPAT(t *testing.T) {
|
||||
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
|
||||
|
||||
manager := &mockInstanceManager{isSetupRequired: true}
|
||||
router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{})
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response api.SetupResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
|
||||
assert.Nil(t, response.PersonalAccessToken)
|
||||
}
|
||||
|
||||
func TestSetup_PAT_MissingExpireIn_DefaultsToOneDay(t *testing.T) {
|
||||
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
|
||||
|
||||
createCalled := false
|
||||
manager := &mockInstanceManager{
|
||||
isSetupRequired: true,
|
||||
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
createCalled = true
|
||||
return &idp.UserData{ID: "u1", Email: email, Name: name}, nil
|
||||
},
|
||||
}
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
|
||||
assert.Equal(t, "u1", userAuth.UserId)
|
||||
return "acc-1", nil
|
||||
},
|
||||
CreatePATFunc: func(_ context.Context, accountID, initiator, target, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
assert.Equal(t, "acc-1", accountID)
|
||||
assert.Equal(t, "u1", initiator)
|
||||
assert.Equal(t, "u1", target)
|
||||
assert.Equal(t, "setup-token", name)
|
||||
assert.Equal(t, 1, expiresIn)
|
||||
return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil
|
||||
},
|
||||
}
|
||||
router := setupTestRouterWithPAT(manager, accountMgr)
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.True(t, createCalled)
|
||||
var response api.SetupResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
|
||||
require.NotNil(t, response.PersonalAccessToken)
|
||||
assert.Equal(t, "nbp_plain", *response.PersonalAccessToken)
|
||||
}
|
||||
|
||||
func TestSetup_PAT_ExpireOutOfRange(t *testing.T) {
|
||||
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
|
||||
|
||||
manager := &mockInstanceManager{isSetupRequired: true}
|
||||
router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{})
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 0}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
|
||||
}
|
||||
|
||||
func TestSetup_PAT_Success(t *testing.T) {
|
||||
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
|
||||
|
||||
manager := &mockInstanceManager{
|
||||
isSetupRequired: true,
|
||||
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
|
||||
},
|
||||
}
|
||||
|
||||
gotAccountArgs := struct {
|
||||
userID string
|
||||
email string
|
||||
}{}
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
|
||||
gotAccountArgs.userID = userAuth.UserId
|
||||
gotAccountArgs.email = userAuth.Email
|
||||
return "acc-1", nil
|
||||
},
|
||||
CreatePATFunc: func(_ context.Context, accountID, initiator, target, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
assert.Equal(t, "acc-1", accountID)
|
||||
assert.Equal(t, "owner-id", initiator)
|
||||
assert.Equal(t, "owner-id", target)
|
||||
assert.Equal(t, "setup-token", name)
|
||||
assert.Equal(t, 30, expiresIn)
|
||||
return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil
|
||||
},
|
||||
}
|
||||
|
||||
router := setupTestRouterWithPAT(manager, accountMgr)
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
var response api.SetupResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
|
||||
assert.Equal(t, "owner-id", response.UserId)
|
||||
require.NotNil(t, response.PersonalAccessToken)
|
||||
assert.Equal(t, "nbp_plain", *response.PersonalAccessToken)
|
||||
assert.Equal(t, "owner-id", gotAccountArgs.userID)
|
||||
}
|
||||
|
||||
func TestSetup_PAT_AccountCreationFails_Rollback(t *testing.T) {
|
||||
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
|
||||
|
||||
rolledBackFor := ""
|
||||
manager := &mockInstanceManager{
|
||||
isSetupRequired: true,
|
||||
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
|
||||
},
|
||||
rollbackSetupFn: func(_ context.Context, userID string) error {
|
||||
rolledBackFor = userID
|
||||
return nil
|
||||
},
|
||||
}
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
|
||||
return "", errors.New("db down")
|
||||
},
|
||||
}
|
||||
|
||||
router := setupTestRouterWithPAT(manager, accountMgr)
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Equal(t, "owner-id", rolledBackFor, "RollbackSetup must be called with the created user id")
|
||||
}
|
||||
|
||||
func TestSetup_PAT_CreatePATFails_Rollback(t *testing.T) {
|
||||
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
accountStore := nbstore.NewMockStore(ctrl)
|
||||
account := &types.Account{Id: "acc-1"}
|
||||
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
|
||||
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil)
|
||||
|
||||
rolledBackFor := ""
|
||||
manager := &mockInstanceManager{
|
||||
isSetupRequired: true,
|
||||
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
|
||||
},
|
||||
rollbackSetupFn: func(_ context.Context, userID string) error {
|
||||
rolledBackFor = userID
|
||||
return nil
|
||||
},
|
||||
}
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
|
||||
return "acc-1", nil
|
||||
},
|
||||
CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
return nil, status.Errorf(status.Internal, "token store unavailable")
|
||||
},
|
||||
GetStoreFunc: func() nbstore.Store {
|
||||
return accountStore
|
||||
},
|
||||
}
|
||||
|
||||
router := setupTestRouterWithPAT(manager, accountMgr)
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Equal(t, "owner-id", rolledBackFor, "RollbackSetup must be called when CreatePAT fails")
|
||||
}
|
||||
|
||||
func TestGetVersionInfo_Success(t *testing.T) {
|
||||
manager := &mockInstanceManager{}
|
||||
router := mux.NewRouter()
|
||||
|
||||
@@ -417,7 +417,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
||||
|
||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
netMap := account.GetPeerNetworkMapFromComponents(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
}
|
||||
|
||||
@@ -43,14 +43,9 @@ func NewAuthMiddleware(
|
||||
ensureAccount EnsureAccountFunc,
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
rateLimiterConfig *RateLimiterConfig,
|
||||
rateLimiter *APIRateLimiter,
|
||||
meter metric.Meter,
|
||||
) *AuthMiddleware {
|
||||
var rateLimiter *APIRateLimiter
|
||||
if rateLimiterConfig != nil {
|
||||
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
|
||||
}
|
||||
|
||||
var patUsageTracker *PATUsageTracker
|
||||
if meter != nil {
|
||||
var err error
|
||||
@@ -181,10 +176,8 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
m.patUsageTracker.IncrementUsage(token)
|
||||
}
|
||||
|
||||
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
||||
if !m.rateLimiter.Allow(token) {
|
||||
return status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) {
|
||||
return status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
@@ -196,6 +196,8 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
disabledLimiter := NewAPIRateLimiter(nil)
|
||||
disabledLimiter.SetEnabled(false)
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
@@ -207,7 +209,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
disabledLimiter,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -266,7 +268,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -318,7 +320,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -361,7 +363,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -405,7 +407,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -469,7 +471,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -528,7 +530,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -583,7 +585,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -670,6 +672,8 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
disabledLimiter := NewAPIRateLimiter(nil)
|
||||
disabledLimiter.SetEnabled(false)
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
@@ -681,7 +685,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
disabledLimiter,
|
||||
nil,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,14 +4,27 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
const (
|
||||
RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED"
|
||||
RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST"
|
||||
RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM"
|
||||
|
||||
defaultAPIRPM = 6
|
||||
defaultAPIBurst = 500
|
||||
)
|
||||
|
||||
// RateLimiterConfig holds configuration for the API rate limiter
|
||||
type RateLimiterConfig struct {
|
||||
// RequestsPerMinute defines the rate at which tokens are replenished
|
||||
@@ -34,6 +47,43 @@ func DefaultRateLimiterConfig() *RateLimiterConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) {
|
||||
rpm := defaultAPIRPM
|
||||
if v := os.Getenv(RateLimitingRPMEnv); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm)
|
||||
} else {
|
||||
rpm = value
|
||||
}
|
||||
}
|
||||
if rpm <= 0 {
|
||||
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingRPMEnv, rpm, defaultAPIRPM)
|
||||
rpm = defaultAPIRPM
|
||||
}
|
||||
|
||||
burst := defaultAPIBurst
|
||||
if v := os.Getenv(RateLimitingBurstEnv); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst)
|
||||
} else {
|
||||
burst = value
|
||||
}
|
||||
}
|
||||
if burst <= 0 {
|
||||
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingBurstEnv, burst, defaultAPIBurst)
|
||||
burst = defaultAPIBurst
|
||||
}
|
||||
|
||||
return &RateLimiterConfig{
|
||||
RequestsPerMinute: float64(rpm),
|
||||
Burst: burst,
|
||||
CleanupInterval: 6 * time.Hour,
|
||||
LimiterTTL: 24 * time.Hour,
|
||||
}, os.Getenv(RateLimitingEnabledEnv) == "true"
|
||||
}
|
||||
|
||||
// limiterEntry holds a rate limiter and its last access time
|
||||
type limiterEntry struct {
|
||||
limiter *rate.Limiter
|
||||
@@ -46,6 +96,7 @@ type APIRateLimiter struct {
|
||||
limiters map[string]*limiterEntry
|
||||
mu sync.RWMutex
|
||||
stopChan chan struct{}
|
||||
enabled atomic.Bool
|
||||
}
|
||||
|
||||
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
|
||||
@@ -59,14 +110,53 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
|
||||
limiters: make(map[string]*limiterEntry),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
rl.enabled.Store(true)
|
||||
|
||||
go rl.cleanupLoop()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) SetEnabled(enabled bool) {
|
||||
rl.enabled.Store(enabled)
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) Enabled() bool {
|
||||
return rl.enabled.Load()
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) {
|
||||
if config == nil {
|
||||
return
|
||||
}
|
||||
if config.RequestsPerMinute <= 0 || config.Burst <= 0 {
|
||||
log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst)
|
||||
return
|
||||
}
|
||||
|
||||
newRPS := rate.Limit(config.RequestsPerMinute / 60.0)
|
||||
newBurst := config.Burst
|
||||
|
||||
rl.mu.Lock()
|
||||
rl.config.RequestsPerMinute = config.RequestsPerMinute
|
||||
rl.config.Burst = newBurst
|
||||
snapshot := make([]*rate.Limiter, 0, len(rl.limiters))
|
||||
for _, entry := range rl.limiters {
|
||||
snapshot = append(snapshot, entry.limiter)
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
|
||||
for _, l := range snapshot {
|
||||
l.SetLimit(newRPS)
|
||||
l.SetBurst(newBurst)
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request for the given key (token) is allowed
|
||||
func (rl *APIRateLimiter) Allow(key string) bool {
|
||||
if !rl.enabled.Load() {
|
||||
return true
|
||||
}
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Allow()
|
||||
}
|
||||
@@ -74,6 +164,9 @@ func (rl *APIRateLimiter) Allow(key string) bool {
|
||||
// Wait blocks until the rate limiter allows another request for the given key
|
||||
// Returns an error if the context is canceled
|
||||
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
|
||||
if !rl.enabled.Load() {
|
||||
return nil
|
||||
}
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Wait(ctx)
|
||||
}
|
||||
@@ -153,6 +246,10 @@ func (rl *APIRateLimiter) Reset(key string) {
|
||||
// Returns 429 Too Many Requests if the rate limit is exceeded.
|
||||
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !rl.enabled.Load() {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
clientIP := getClientIP(r)
|
||||
if !rl.Allow(clientIP) {
|
||||
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -156,3 +158,172 @@ func TestAPIRateLimiter_Reset(t *testing.T) {
|
||||
// Should be allowed again
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_SetEnabled(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
assert.True(t, rl.Allow("key"))
|
||||
assert.False(t, rl.Allow("key"), "burst exhausted while enabled")
|
||||
|
||||
rl.SetEnabled(false)
|
||||
assert.False(t, rl.Enabled())
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.True(t, rl.Allow("key"), "disabled limiter must always allow")
|
||||
}
|
||||
|
||||
rl.SetEnabled(true)
|
||||
assert.True(t, rl.Enabled())
|
||||
assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state")
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_UpdateConfig(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 2,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
assert.True(t, rl.Allow("k1"))
|
||||
assert.True(t, rl.Allow("k1"))
|
||||
assert.False(t, rl.Allow("k1"), "burst=2 exhausted")
|
||||
|
||||
rl.UpdateConfig(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 10,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
|
||||
// New burst applies to existing keys in place; bucket refills up to new burst over time,
|
||||
// but importantly newly-added keys use the updated config immediately.
|
||||
assert.True(t, rl.Allow("k2"))
|
||||
for i := 0; i < 9; i++ {
|
||||
assert.True(t, rl.Allow("k2"))
|
||||
}
|
||||
assert.False(t, rl.Allow("k2"), "new burst=10 exhausted")
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
rl.UpdateConfig(nil) // must not panic or zero the config
|
||||
|
||||
assert.True(t, rl.Allow("k"))
|
||||
assert.False(t, rl.Allow("k"))
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
assert.True(t, rl.Allow("k"))
|
||||
assert.False(t, rl.Allow("k"))
|
||||
|
||||
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
||||
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
||||
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
||||
|
||||
rl.Reset("k")
|
||||
assert.True(t, rl.Allow("k"))
|
||||
assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored")
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 600,
|
||||
Burst: 10,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stop := make(chan struct{})
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
key := fmt.Sprintf("k%d", id)
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
rl.Allow(key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 200; i++ {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
rl.UpdateConfig(&RateLimiterConfig{
|
||||
RequestsPerMinute: float64(30 + (i % 90)),
|
||||
Burst: 1 + (i % 20),
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
rl.SetEnabled(i%2 == 0)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
close(stop)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestRateLimiterConfigFromEnv(t *testing.T) {
|
||||
t.Setenv(RateLimitingEnabledEnv, "true")
|
||||
t.Setenv(RateLimitingRPMEnv, "42")
|
||||
t.Setenv(RateLimitingBurstEnv, "7")
|
||||
|
||||
cfg, enabled := RateLimiterConfigFromEnv()
|
||||
assert.True(t, enabled)
|
||||
assert.Equal(t, float64(42), cfg.RequestsPerMinute)
|
||||
assert.Equal(t, 7, cfg.Burst)
|
||||
|
||||
t.Setenv(RateLimitingEnabledEnv, "false")
|
||||
_, enabled = RateLimiterConfigFromEnv()
|
||||
assert.False(t, enabled)
|
||||
|
||||
t.Setenv(RateLimitingEnabledEnv, "")
|
||||
t.Setenv(RateLimitingRPMEnv, "")
|
||||
t.Setenv(RateLimitingBurstEnv, "")
|
||||
cfg, enabled = RateLimiterConfigFromEnv()
|
||||
assert.False(t, enabled)
|
||||
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute)
|
||||
assert.Equal(t, defaultAPIBurst, cfg.Burst)
|
||||
|
||||
t.Setenv(RateLimitingRPMEnv, "0")
|
||||
t.Setenv(RateLimitingBurstEnv, "-5")
|
||||
cfg, _ = RateLimiterConfigFromEnv()
|
||||
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default")
|
||||
assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default")
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -264,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
goversion "github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -60,6 +61,13 @@ type Manager interface {
|
||||
// This should only be called when IsSetupRequired returns true.
|
||||
CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
|
||||
// RollbackSetup reverses a successful CreateOwnerUser by deleting the user
|
||||
// from the embedded IDP and reloading setupRequired from persistent state, so
|
||||
// /api/setup can be retried only when no accounts or local users remain. Used
|
||||
// when post-user steps (account or PAT creation) fail and the caller wants a
|
||||
// clean slate.
|
||||
RollbackSetup(ctx context.Context, userID string) error
|
||||
|
||||
// GetVersionInfo returns version information for NetBird components.
|
||||
GetVersionInfo(ctx context.Context) (*VersionInfo, error)
|
||||
}
|
||||
@@ -70,6 +78,7 @@ type instanceStore interface {
|
||||
|
||||
type embeddedIdP interface {
|
||||
CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
DeleteUser(ctx context.Context, userID string) error
|
||||
GetAllAccounts(ctx context.Context) (map[string][]*idp.UserData, error)
|
||||
}
|
||||
|
||||
@@ -187,6 +196,51 @@ func (m *DefaultManager) CreateOwnerUser(ctx context.Context, email, password, n
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// RollbackSetup undoes a successful CreateOwnerUser: deletes the user from the
|
||||
// embedded IDP and reloads setupRequired from persistent state.
|
||||
func (m *DefaultManager) RollbackSetup(ctx context.Context, userID string) error {
|
||||
if m.embeddedIdpManager == nil {
|
||||
return errors.New("embedded IDP is not enabled")
|
||||
}
|
||||
|
||||
var deleteErr error
|
||||
if err := m.embeddedIdpManager.DeleteUser(ctx, userID); err != nil {
|
||||
if isNotFoundError(err) {
|
||||
log.WithContext(ctx).Debugf("setup rollback user %s already deleted", userID)
|
||||
} else {
|
||||
deleteErr = fmt.Errorf("failed to delete user from embedded IdP: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.loadSetupRequired(ctx); err != nil {
|
||||
reloadErr := fmt.Errorf("failed to reload setup state after rollback: %w", err)
|
||||
if deleteErr != nil {
|
||||
return errors.Join(deleteErr, reloadErr)
|
||||
}
|
||||
return reloadErr
|
||||
}
|
||||
|
||||
if deleteErr != nil {
|
||||
return deleteErr
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("rolled back setup for user %s", userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func isNotFoundError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return true
|
||||
}
|
||||
if s, ok := status.FromError(err); ok {
|
||||
return s.Type() == status.NotFound
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *DefaultManager) checkSetupRequiredFromDB(ctx context.Context) error {
|
||||
numAccounts, err := m.store.GetAccountsCounter(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -10,16 +10,19 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type mockIdP struct {
|
||||
mu sync.Mutex
|
||||
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
users map[string][]*idp.UserData
|
||||
mu sync.Mutex
|
||||
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
deleteUserFunc func(ctx context.Context, userID string) error
|
||||
users map[string][]*idp.UserData
|
||||
getAllAccountsErr error
|
||||
}
|
||||
|
||||
@@ -30,6 +33,13 @@ func (m *mockIdP) CreateUserWithPassword(ctx context.Context, email, password, n
|
||||
return &idp.UserData{ID: "test-user-id", Email: email, Name: name}, nil
|
||||
}
|
||||
|
||||
func (m *mockIdP) DeleteUser(ctx context.Context, userID string) error {
|
||||
if m.deleteUserFunc != nil {
|
||||
return m.deleteUserFunc(ctx, userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockIdP) GetAllAccounts(_ context.Context) (map[string][]*idp.UserData, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -223,6 +233,77 @@ func TestIsSetupRequired_ReturnsFlag(t *testing.T) {
|
||||
assert.False(t, required)
|
||||
}
|
||||
|
||||
func TestRollbackSetup_UserAlreadyDeletedIsSuccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "management status not found",
|
||||
err: status.NewUserNotFoundError("owner-id"),
|
||||
},
|
||||
{
|
||||
name: "dex storage not found",
|
||||
err: fmt.Errorf("failed to get user for deletion: %w", storage.ErrNotFound),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
idpMock := &mockIdP{
|
||||
deleteUserFunc: func(_ context.Context, userID string) error {
|
||||
assert.Equal(t, "owner-id", userID)
|
||||
return tt.err
|
||||
},
|
||||
}
|
||||
mgr := newTestManager(idpMock, &mockStore{})
|
||||
mgr.setupRequired = false
|
||||
|
||||
err := mgr.RollbackSetup(context.Background(), "owner-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
required, err := mgr.IsSetupRequired(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.True(t, required, "setup should be required when no accounts or local users remain")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRollbackSetup_RecomputesSetupStateWhenAccountStillExists(t *testing.T) {
|
||||
idpMock := &mockIdP{
|
||||
deleteUserFunc: func(_ context.Context, _ string) error {
|
||||
return status.NewUserNotFoundError("owner-id")
|
||||
},
|
||||
}
|
||||
mgr := newTestManager(idpMock, &mockStore{accountsCount: 1})
|
||||
mgr.setupRequired = true
|
||||
|
||||
err := mgr.RollbackSetup(context.Background(), "owner-id")
|
||||
require.NoError(t, err)
|
||||
|
||||
required, err := mgr.IsSetupRequired(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.False(t, required, "setup should not be required while an account still exists")
|
||||
}
|
||||
|
||||
func TestRollbackSetup_ReturnsDeleteErrorButReloadsSetupState(t *testing.T) {
|
||||
idpMock := &mockIdP{
|
||||
deleteUserFunc: func(_ context.Context, _ string) error {
|
||||
return errors.New("idp unavailable")
|
||||
},
|
||||
}
|
||||
mgr := newTestManager(idpMock, &mockStore{})
|
||||
mgr.setupRequired = false
|
||||
|
||||
err := mgr.RollbackSetup(context.Background(), "owner-id")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "idp unavailable")
|
||||
|
||||
required, err := mgr.IsSetupRequired(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.True(t, required, "setup state should be reloaded even when user deletion fails")
|
||||
}
|
||||
|
||||
func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
|
||||
manager := &DefaultManager{setupRequired: true}
|
||||
|
||||
|
||||
185
management/server/instance/setup_service.go
Normal file
185
management/server/instance/setup_service.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package instance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
setupPATTokenName = "setup-token"
|
||||
|
||||
// SetupPATEnabledEnvKey enables setup-time Personal Access Token creation.
|
||||
SetupPATEnabledEnvKey = "NB_SETUP_PAT_ENABLED"
|
||||
|
||||
setupPATDefaultExpireDays = 1
|
||||
|
||||
// patMinExpireDays and patMaxExpireDays mirror the bounds enforced by
|
||||
// DefaultAccountManager.CreatePAT in management/server/user.go. They are
|
||||
// duplicated here so /api/setup can reject invalid input before it creates
|
||||
// the embedded-IdP user.
|
||||
patMinExpireDays = 1
|
||||
patMaxExpireDays = 365
|
||||
)
|
||||
|
||||
// SetupOptions controls optional work performed during initial instance setup.
|
||||
type SetupOptions struct {
|
||||
// CreatePAT requests creation of a setup Personal Access Token. It is honored
|
||||
// only when SetupPATEnabledEnvKey is set to "true".
|
||||
CreatePAT bool
|
||||
// PATExpireInDays defaults to 1 day when CreatePAT is requested and setup PAT
|
||||
// creation is enabled.
|
||||
PATExpireInDays *int
|
||||
}
|
||||
|
||||
// SetupResult contains resources created during initial instance setup.
|
||||
type SetupResult struct {
|
||||
User *idp.UserData
|
||||
PATPlainToken string
|
||||
}
|
||||
|
||||
// SetupService orchestrates the initial setup use case across the instance and
|
||||
// account bounded contexts and owns the compensation logic when a later step
|
||||
// fails.
|
||||
type SetupService struct {
|
||||
instanceManager Manager
|
||||
accountManager account.Manager
|
||||
setupPATEnabled bool
|
||||
}
|
||||
|
||||
// NewSetupService creates a setup use-case service.
|
||||
func NewSetupService(instanceManager Manager, accountManager account.Manager) *SetupService {
|
||||
return &SetupService{
|
||||
instanceManager: instanceManager,
|
||||
accountManager: accountManager,
|
||||
setupPATEnabled: os.Getenv(SetupPATEnabledEnvKey) == "true",
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSetupOptions(opts SetupOptions, setupPATEnabled bool) (SetupOptions, error) {
|
||||
if !opts.CreatePAT {
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
if !setupPATEnabled {
|
||||
opts.CreatePAT = false
|
||||
opts.PATExpireInDays = nil
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
if opts.PATExpireInDays == nil {
|
||||
defaultExpireInDays := setupPATDefaultExpireDays
|
||||
opts.PATExpireInDays = &defaultExpireInDays
|
||||
}
|
||||
|
||||
if *opts.PATExpireInDays < patMinExpireDays || *opts.PATExpireInDays > patMaxExpireDays {
|
||||
return opts, status.Errorf(status.InvalidArgument, "pat_expire_in must be between %d and %d", patMinExpireDays, patMaxExpireDays)
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
// SetupOwner creates the initial owner user and, when requested and enabled by
|
||||
// SetupPATEnabledEnvKey, provisions the account and a setup Personal Access
|
||||
// Token. If account or PAT provisioning fails, created resources are rolled
|
||||
// back so setup can be retried.
|
||||
func (m *SetupService) SetupOwner(ctx context.Context, email, password, name string, opts SetupOptions) (*SetupResult, error) {
|
||||
opts, err := normalizeSetupOptions(opts, m.setupPATEnabled)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userData, err := m.instanceManager.CreateOwnerUser(ctx, email, password, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &SetupResult{User: userData}
|
||||
if !opts.CreatePAT {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if m.accountManager == nil {
|
||||
err := fmt.Errorf("account manager is required to create setup PAT")
|
||||
m.rollbackSetup(ctx, userData.ID, "setup PAT requested without account manager", err, "")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userAuth := auth.UserAuth{
|
||||
UserId: userData.ID,
|
||||
Email: userData.Email,
|
||||
Name: userData.Name,
|
||||
}
|
||||
|
||||
accountID, err := m.accountManager.GetAccountIDByUserID(ctx, userAuth)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("create account for setup user: %w", err)
|
||||
m.rollbackSetup(ctx, userData.ID, "account provisioning failed", err, "")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pat, err := m.accountManager.CreatePAT(ctx, accountID, userData.ID, userData.ID, setupPATTokenName, *opts.PATExpireInDays)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("create setup PAT: %w", err)
|
||||
m.rollbackSetup(ctx, userData.ID, "setup PAT provisioning failed", err, accountID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result.PATPlainToken = pat.PlainToken
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *SetupService) rollbackSetup(ctx context.Context, userID, reason string, origErr error, accountID string) {
|
||||
if accountID != "" {
|
||||
if err := m.rollbackSetupAccount(ctx, accountID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to roll back setup account %s for user %s after %s: original error: %v, rollback error: %v", accountID, userID, reason, origErr, err)
|
||||
} else {
|
||||
log.WithContext(ctx).Warnf("rolled back setup account %s for user %s after %s: %v", accountID, userID, reason, origErr)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.instanceManager.RollbackSetup(ctx, userID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to roll back setup user %s after %s: original error: %v, rollback error: %v", userID, reason, origErr, err)
|
||||
return
|
||||
}
|
||||
log.WithContext(ctx).Warnf("rolled back setup user %s after %s: %v", userID, reason, origErr)
|
||||
}
|
||||
|
||||
// rollbackSetupAccount removes only the setup-created account data from the
|
||||
// store. It intentionally avoids accountManager.DeleteAccount because the normal
|
||||
// account deletion path also deletes users from the IdP; embedded IdP cleanup is
|
||||
// owned by instanceManager.RollbackSetup.
|
||||
func (m *SetupService) rollbackSetupAccount(ctx context.Context, accountID string) error {
|
||||
if m.accountManager == nil {
|
||||
return fmt.Errorf("account manager is required to roll back setup account")
|
||||
}
|
||||
|
||||
accountStore := m.accountManager.GetStore()
|
||||
if accountStore == nil {
|
||||
return fmt.Errorf("account store is unavailable")
|
||||
}
|
||||
|
||||
account, err := accountStore.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
if isNotFoundError(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("get setup account for rollback: %w", err)
|
||||
}
|
||||
|
||||
if err := accountStore.DeleteAccount(ctx, account); err != nil {
|
||||
if isNotFoundError(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("delete setup account for rollback: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
241
management/server/instance/setup_service_test.go
Normal file
241
management/server/instance/setup_service_test.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package instance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbstore "github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type setupInstanceManagerMock struct {
|
||||
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
rollbackSetupFn func(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
func (m *setupInstanceManagerMock) IsSetupRequired(context.Context) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *setupInstanceManagerMock) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
if m.createOwnerUserFn != nil {
|
||||
return m.createOwnerUserFn(ctx, email, password, name)
|
||||
}
|
||||
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
|
||||
}
|
||||
|
||||
func (m *setupInstanceManagerMock) RollbackSetup(ctx context.Context, userID string) error {
|
||||
if m.rollbackSetupFn != nil {
|
||||
return m.rollbackSetupFn(ctx, userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *setupInstanceManagerMock) GetVersionInfo(context.Context) (*VersionInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var _ Manager = (*setupInstanceManagerMock)(nil)
|
||||
|
||||
func intPtr(v int) *int {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestSetupOwner_PATFeatureDisabled_IgnoresCreatePAT(t *testing.T) {
|
||||
t.Setenv(SetupPATEnabledEnvKey, "false")
|
||||
|
||||
createCalls := 0
|
||||
setupManager := NewSetupService(
|
||||
&setupInstanceManagerMock{
|
||||
createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) {
|
||||
createCalls++
|
||||
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
|
||||
},
|
||||
},
|
||||
&mock_server.MockAccountManager{},
|
||||
)
|
||||
|
||||
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
|
||||
CreatePAT: true,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "owner-id", result.User.ID)
|
||||
assert.Empty(t, result.PATPlainToken)
|
||||
assert.Equal(t, 1, createCalls)
|
||||
}
|
||||
|
||||
func TestSetupOwner_PATFeatureEnabled_MissingExpireDefaultsToOneDay(t *testing.T) {
|
||||
t.Setenv(SetupPATEnabledEnvKey, "true")
|
||||
|
||||
createCalled := false
|
||||
setupManager := NewSetupService(
|
||||
&setupInstanceManagerMock{
|
||||
createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) {
|
||||
createCalled = true
|
||||
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
|
||||
},
|
||||
},
|
||||
&mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
|
||||
assert.Equal(t, "owner-id", userAuth.UserId)
|
||||
return "acc-1", nil
|
||||
},
|
||||
CreatePATFunc: func(_ context.Context, accountID, initiatorUserID, targetUserID, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
assert.Equal(t, "acc-1", accountID)
|
||||
assert.Equal(t, "owner-id", initiatorUserID)
|
||||
assert.Equal(t, "owner-id", targetUserID)
|
||||
assert.Equal(t, setupPATTokenName, tokenName)
|
||||
assert.Equal(t, setupPATDefaultExpireDays, expiresIn)
|
||||
return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
|
||||
CreatePAT: true,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.True(t, createCalled)
|
||||
assert.Equal(t, "nbp_plain", result.PATPlainToken)
|
||||
}
|
||||
|
||||
func TestSetupOwner_CreatePATFails_RollsBackSetupAccountAndUser(t *testing.T) {
|
||||
t.Setenv(SetupPATEnabledEnvKey, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
accountStore := nbstore.NewMockStore(ctrl)
|
||||
account := &types.Account{Id: "acc-1"}
|
||||
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
|
||||
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil)
|
||||
|
||||
rollbackCalls := 0
|
||||
setupManager := NewSetupService(
|
||||
&setupInstanceManagerMock{
|
||||
rollbackSetupFn: func(_ context.Context, userID string) error {
|
||||
rollbackCalls++
|
||||
assert.Equal(t, "owner-id", userID)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
&mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
|
||||
assert.Equal(t, "owner-id", userAuth.UserId)
|
||||
return "acc-1", nil
|
||||
},
|
||||
CreatePATFunc: func(_ context.Context, accountID, initiatorUserID, targetUserID, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
assert.Equal(t, "acc-1", accountID)
|
||||
assert.Equal(t, "owner-id", initiatorUserID)
|
||||
assert.Equal(t, "owner-id", targetUserID)
|
||||
assert.Equal(t, setupPATTokenName, tokenName)
|
||||
assert.Equal(t, 30, expiresIn)
|
||||
return nil, status.Errorf(status.Internal, "token store unavailable")
|
||||
},
|
||||
GetStoreFunc: func() nbstore.Store {
|
||||
return accountStore
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
|
||||
CreatePAT: true,
|
||||
PATExpireInDays: intPtr(30),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "create setup PAT")
|
||||
assert.Equal(t, 1, rollbackCalls)
|
||||
}
|
||||
|
||||
func TestSetupOwner_CreatePATFails_AccountAlreadyGoneStillRollsBackUser(t *testing.T) {
|
||||
t.Setenv(SetupPATEnabledEnvKey, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
accountStore := nbstore.NewMockStore(ctrl)
|
||||
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(nil, status.NewAccountNotFoundError("acc-1"))
|
||||
|
||||
rolledBackFor := ""
|
||||
setupManager := NewSetupService(
|
||||
&setupInstanceManagerMock{
|
||||
rollbackSetupFn: func(_ context.Context, userID string) error {
|
||||
rolledBackFor = userID
|
||||
return nil
|
||||
},
|
||||
},
|
||||
&mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
|
||||
return "acc-1", nil
|
||||
},
|
||||
CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
return nil, errors.New("token failure")
|
||||
},
|
||||
GetStoreFunc: func() nbstore.Store {
|
||||
return accountStore
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
|
||||
CreatePAT: true,
|
||||
PATExpireInDays: intPtr(30),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "create setup PAT")
|
||||
assert.Equal(t, "owner-id", rolledBackFor)
|
||||
}
|
||||
|
||||
func TestSetupOwner_CreatePATFails_AccountRollbackFailureStillRollsBackUser(t *testing.T) {
|
||||
t.Setenv(SetupPATEnabledEnvKey, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
accountStore := nbstore.NewMockStore(ctrl)
|
||||
account := &types.Account{Id: "acc-1"}
|
||||
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
|
||||
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(errors.New("delete failed"))
|
||||
|
||||
rolledBackFor := ""
|
||||
setupManager := NewSetupService(
|
||||
&setupInstanceManagerMock{
|
||||
rollbackSetupFn: func(_ context.Context, userID string) error {
|
||||
rolledBackFor = userID
|
||||
return nil
|
||||
},
|
||||
},
|
||||
&mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
|
||||
return "acc-1", nil
|
||||
},
|
||||
CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) {
|
||||
return nil, errors.New("token failure")
|
||||
},
|
||||
GetStoreFunc: func() nbstore.Store {
|
||||
return accountStore
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
|
||||
CreatePAT: true,
|
||||
PATExpireInDays: intPtr(30),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "create setup PAT")
|
||||
assert.Equal(t, "owner-id", rolledBackFor)
|
||||
}
|
||||
@@ -267,8 +267,8 @@ func Test_SyncProtocol(t *testing.T) {
|
||||
}
|
||||
|
||||
// expired peers come separately.
|
||||
if len(networkMap.GetOfflinePeers()) != 1 {
|
||||
t.Fatal("expecting SyncResponse to have NetworkMap with 1 offline peer")
|
||||
if len(networkMap.GetOfflinePeers()) != 2 {
|
||||
t.Fatal("expecting SyncResponse to have NetworkMap with 2 offline peer")
|
||||
}
|
||||
|
||||
expiredPeerPubKey := "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4="
|
||||
|
||||
@@ -1087,7 +1087,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1105,7 +1105,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1405,6 +1405,10 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
for _, peer := range peersWithExpiry {
|
||||
if peer.Status.LoginExpired {
|
||||
continue
|
||||
}
|
||||
|
||||
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
|
||||
if expired {
|
||||
peers = append(peers, peer)
|
||||
|
||||
@@ -179,11 +179,6 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
testGetNetworkMapGeneral(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testGetNetworkMapGeneral(t)
|
||||
}
|
||||
|
||||
func testGetNetworkMapGeneral(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
if err != nil {
|
||||
@@ -1016,11 +1011,6 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAccountPeers_Experimental(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
testUpdateAccountPeers(t)
|
||||
}
|
||||
|
||||
func TestUpdateAccountPeers(t *testing.T) {
|
||||
testUpdateAccountPeers(t)
|
||||
}
|
||||
@@ -1600,7 +1590,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_LoginPeer(t *testing.T) {
|
||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
@@ -1907,7 +1896,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1929,7 +1918,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1994,7 +1983,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2012,7 +2001,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2058,7 +2047,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2076,7 +2065,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2113,7 +2102,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2131,7 +2120,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1231,7 +1231,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1263,7 +1263,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1294,7 +1294,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1314,7 +1314,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1355,7 +1355,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1373,7 +1373,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
|
||||
@@ -1393,7 +1393,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -244,7 +244,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -273,7 +273,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -292,7 +292,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -395,7 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -438,7 +438,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -2,10 +2,8 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -1840,11 +1838,6 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
validatedPeers := make(map[string]struct{})
|
||||
for p := range account.Peers {
|
||||
validatedPeers[p] = struct{}{}
|
||||
}
|
||||
|
||||
t.Run("check applied policies for the route", func(t *testing.T) {
|
||||
route1 := account.Routes["route1"]
|
||||
policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups)
|
||||
@@ -1858,116 +1851,6 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
||||
policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups)
|
||||
assert.Len(t, policies, 0)
|
||||
})
|
||||
|
||||
t.Run("check peer routes firewall rules", func(t *testing.T) {
|
||||
routesFirewallRules := account.GetPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
|
||||
assert.Len(t, routesFirewallRules, 4)
|
||||
|
||||
expectedRoutesFirewallRules := []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.0.0/16",
|
||||
Protocol: "all",
|
||||
Port: 80,
|
||||
RouteID: "route1:peerA",
|
||||
},
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.0.0/16",
|
||||
Protocol: "all",
|
||||
Port: 320,
|
||||
RouteID: "route1:peerA",
|
||||
},
|
||||
}
|
||||
additionalFirewallRule := []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerJIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.10.0/16",
|
||||
Protocol: "tcp",
|
||||
Port: 80,
|
||||
RouteID: "route4:peerA",
|
||||
},
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerKIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.10.0/16",
|
||||
Protocol: "all",
|
||||
RouteID: "route4:peerA",
|
||||
},
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...)))
|
||||
|
||||
// peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
|
||||
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
|
||||
assert.Len(t, routesFirewallRules, 2)
|
||||
for _, rule := range expectedRoutesFirewallRules {
|
||||
rule.RouteID = "route1:peerD"
|
||||
}
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
|
||||
|
||||
// peerE is a single routing peer for route 2 and route 3
|
||||
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
|
||||
assert.Len(t, routesFirewallRules, 3)
|
||||
|
||||
expectedRoutesFirewallRules = []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
|
||||
Action: "accept",
|
||||
Destination: existingNetwork.String(),
|
||||
Protocol: "tcp",
|
||||
PortRange: types.RulePortRange{Start: 80, End: 350},
|
||||
RouteID: "route2",
|
||||
},
|
||||
{
|
||||
SourceRanges: []string{"0.0.0.0/0"},
|
||||
Action: "accept",
|
||||
Destination: "192.0.2.0/32",
|
||||
Protocol: "all",
|
||||
Domains: domain.List{"example.com"},
|
||||
IsDynamic: true,
|
||||
RouteID: "route3",
|
||||
},
|
||||
{
|
||||
SourceRanges: []string{"::/0"},
|
||||
Action: "accept",
|
||||
Destination: "192.0.2.0/32",
|
||||
Protocol: "all",
|
||||
Domains: domain.List{"example.com"},
|
||||
IsDynamic: true,
|
||||
RouteID: "route3",
|
||||
},
|
||||
}
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
|
||||
|
||||
// peerC is part of route1 distribution groups but should not receive the routes firewall rules
|
||||
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
|
||||
assert.Len(t, routesFirewallRules, 0)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// orderList is a helper function to sort a list of strings
|
||||
func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFirewallRule {
|
||||
for _, rule := range ruleList {
|
||||
sort.Strings(rule.SourceRanges)
|
||||
}
|
||||
return ruleList
|
||||
}
|
||||
|
||||
func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
@@ -2070,7 +1953,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
|
||||
@@ -2107,7 +1990,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2127,7 +2010,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2145,7 +2028,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2185,7 +2068,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2225,7 +2108,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2665,11 +2548,6 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
validatedPeers := make(map[string]struct{})
|
||||
for p := range account.Peers {
|
||||
validatedPeers[p] = struct{}{}
|
||||
}
|
||||
|
||||
t.Run("validate applied policies for different network resources", func(t *testing.T) {
|
||||
// Test case: Resource1 is directly applied to the policy (policyResource1)
|
||||
policies := account.GetPoliciesForNetworkResource("resource1")
|
||||
@@ -2693,127 +2571,4 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) {
|
||||
policies = account.GetPoliciesForNetworkResource("resource6")
|
||||
assert.Len(t, policies, 1, "resource6 should have exactly 1 policy applied via access control groups")
|
||||
})
|
||||
|
||||
t.Run("validate routing peer firewall rules for network resources", func(t *testing.T) {
|
||||
resourcePoliciesMap := account.GetResourcePoliciesMap()
|
||||
resourceRoutersMap := account.GetResourceRoutersMap()
|
||||
_, routes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), "peerA", resourcePoliciesMap, resourceRoutersMap)
|
||||
firewallRules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerA"], validatedPeers, routes, resourcePoliciesMap)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Len(t, sourcePeers, 5)
|
||||
|
||||
expectedFirewallRules := []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.0.0/16",
|
||||
Protocol: "all",
|
||||
Port: 80,
|
||||
RouteID: "resource2:peerA",
|
||||
},
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.168.0.0/16",
|
||||
Protocol: "all",
|
||||
Port: 320,
|
||||
RouteID: "resource2:peerA",
|
||||
},
|
||||
}
|
||||
|
||||
additionalFirewallRules := []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerJIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.0.2.0/32",
|
||||
Protocol: "tcp",
|
||||
Port: 80,
|
||||
Domains: domain.List{"example.com"},
|
||||
IsDynamic: true,
|
||||
RouteID: "resource4:peerA",
|
||||
},
|
||||
{
|
||||
SourceRanges: []string{
|
||||
fmt.Sprintf(types.AllowedIPsFormat, peerKIp),
|
||||
},
|
||||
Action: "accept",
|
||||
Destination: "192.0.2.0/32",
|
||||
Protocol: "all",
|
||||
Domains: domain.List{"example.com"},
|
||||
IsDynamic: true,
|
||||
RouteID: "resource4:peerA",
|
||||
},
|
||||
}
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(append(expectedFirewallRules, additionalFirewallRules...)))
|
||||
|
||||
// peerD is also the routing peer for resource2
|
||||
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerD", resourcePoliciesMap, resourceRoutersMap)
|
||||
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerD"], validatedPeers, routes, resourcePoliciesMap)
|
||||
assert.Len(t, firewallRules, 2)
|
||||
for _, rule := range expectedFirewallRules {
|
||||
rule.RouteID = "resource2:peerD"
|
||||
}
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
|
||||
assert.Len(t, sourcePeers, 3)
|
||||
|
||||
// peerE is a single routing peer for resource1 and resource3
|
||||
// PeerE should only receive rules for resource1 since resource3 has no applied policy
|
||||
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerE", resourcePoliciesMap, resourceRoutersMap)
|
||||
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerE"], validatedPeers, routes, resourcePoliciesMap)
|
||||
assert.Len(t, firewallRules, 1)
|
||||
assert.Len(t, sourcePeers, 2)
|
||||
|
||||
expectedFirewallRules = []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
|
||||
Action: "accept",
|
||||
Destination: "10.10.10.0/24",
|
||||
Protocol: "tcp",
|
||||
PortRange: types.RulePortRange{Start: 80, End: 350},
|
||||
RouteID: "resource1:peerE",
|
||||
},
|
||||
}
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
|
||||
|
||||
// peerC is part of distribution groups for resource2 but should not receive the firewall rules
|
||||
firewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerL is the single routing peer for resource5
|
||||
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerL", resourcePoliciesMap, resourceRoutersMap)
|
||||
assert.Len(t, routes, 1)
|
||||
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerL"], validatedPeers, routes, resourcePoliciesMap)
|
||||
assert.Len(t, firewallRules, 1)
|
||||
assert.Len(t, sourcePeers, 1)
|
||||
|
||||
expectedFirewallRules = []*types.RouteFirewallRule{
|
||||
{
|
||||
SourceRanges: []string{"100.65.29.67/32"},
|
||||
Action: "accept",
|
||||
Destination: "10.12.12.1/32",
|
||||
Protocol: "tcp",
|
||||
Port: 8080,
|
||||
RouteID: "resource5:peerL",
|
||||
},
|
||||
}
|
||||
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
|
||||
|
||||
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerM", resourcePoliciesMap, resourceRoutersMap)
|
||||
assert.Len(t, routes, 1)
|
||||
assert.Len(t, sourcePeers, 0)
|
||||
|
||||
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerN", resourcePoliciesMap, resourceRoutersMap)
|
||||
assert.Len(t, routes, 1)
|
||||
assert.Len(t, sourcePeers, 2)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1196,7 +1196,6 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
||||
account.NameServerGroups[ns.ID] = &ns
|
||||
}
|
||||
account.NameServerGroupsG = nil
|
||||
account.InitOnce()
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
@@ -1635,7 +1634,6 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
if sExtraIntegratedValidatorGroups.Valid {
|
||||
_ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups)
|
||||
}
|
||||
account.InitOnce()
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
@@ -3310,7 +3308,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
result := tx.
|
||||
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
|
||||
Where("login_expiration_enabled = ? AND peer_status_login_expired != ? AND user_id IS NOT NULL AND user_id != ''", true, true).
|
||||
Find(&peers, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error)
|
||||
|
||||
@@ -2729,7 +2729,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) {
|
||||
{
|
||||
name: "should retrieve peers for an existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 4,
|
||||
expectedCount: 5,
|
||||
},
|
||||
{
|
||||
name: "should return no peers for a non-existing account ID",
|
||||
@@ -2751,7 +2751,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) {
|
||||
name: "should filter peers by partial name",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
nameFilter: "host",
|
||||
expectedCount: 3,
|
||||
expectedCount: 4,
|
||||
},
|
||||
{
|
||||
name: "should filter peers by ip",
|
||||
@@ -2777,14 +2777,16 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
expectedCount int
|
||||
name string
|
||||
accountID string
|
||||
expectedCount int
|
||||
expectedPeerIDs []string
|
||||
}{
|
||||
{
|
||||
name: "should retrieve peers with expiration for an existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 1,
|
||||
name: "should retrieve only non-expired peers with expiration enabled",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 1,
|
||||
expectedPeerIDs: []string{"notexpired01"},
|
||||
},
|
||||
{
|
||||
name: "should return no peers with expiration for a non-existing account ID",
|
||||
@@ -2803,10 +2805,30 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
|
||||
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, tt.accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, tt.expectedCount)
|
||||
for i, peer := range peers {
|
||||
assert.Equal(t, tt.expectedPeerIDs[i], peer.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountPeersWithExpiration_ExcludesAlreadyExpired(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the already-expired peer (cg05lnblo1hkg2j514p0) is not returned
|
||||
for _, peer := range peers {
|
||||
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should not be returned")
|
||||
assert.False(t, peer.Status.LoginExpired, "returned peers should not have LoginExpired set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
@@ -2887,7 +2909,7 @@ func TestSqlStore_GetUserPeers(t *testing.T) {
|
||||
name: "should retrieve peers for another valid account ID and user ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
userID: "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||
expectedCount: 2,
|
||||
expectedCount: 3,
|
||||
},
|
||||
{
|
||||
name: "should return no peers for existing account ID with empty user ID",
|
||||
|
||||
@@ -31,6 +31,7 @@ INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-3465300
|
||||
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','nVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HX=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('notexpired01','bf1c8084-ba50-4ce7-9439-34653001fc3b','oVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HY=','','"100.64.117.98"','activehost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'activehost','activehost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,1,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@@ -27,7 +26,6 @@ import (
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -110,16 +108,9 @@ type Account struct {
|
||||
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
|
||||
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
|
||||
|
||||
NetworkMapCache *NetworkMapBuilder `gorm:"-"`
|
||||
nmapInitOnce *sync.Once `gorm:"-"`
|
||||
|
||||
ReverseProxyFreeDomainNonce string
|
||||
}
|
||||
|
||||
func (a *Account) InitOnce() {
|
||||
a.nmapInitOnce = &sync.Once{}
|
||||
}
|
||||
|
||||
// this class is used by gorm only
|
||||
type PrimaryAccountInfo struct {
|
||||
IsDomainPrimaryAccount bool
|
||||
@@ -155,108 +146,6 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
|
||||
o.SignupFormPending == onboarding.SignupFormPending
|
||||
}
|
||||
|
||||
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
|
||||
// from the ACL peers that have distribution groups associated with the peer ID.
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route {
|
||||
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)
|
||||
peerRoutesMembership := make(LookupMap)
|
||||
for _, r := range append(routes, peerDisabledRoutes...) {
|
||||
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
|
||||
}
|
||||
|
||||
for _, peer := range aclPeers {
|
||||
activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID)
|
||||
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, peerGroups)
|
||||
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
|
||||
routes = append(routes, filteredRoutes...)
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership
|
||||
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
_, found := peerMemberships[string(r.GetHAUniqueID())]
|
||||
if !found {
|
||||
filteredRoutes = append(filteredRoutes, r)
|
||||
}
|
||||
}
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map
|
||||
func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
for _, groupID := range r.Groups {
|
||||
_, found := groupListMap[groupID]
|
||||
if found {
|
||||
filteredRoutes = append(filteredRoutes, r)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return filteredRoutes
|
||||
}
|
||||
|
||||
// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves
|
||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
||||
// If the given is not a routing peer, then the lists are empty.
|
||||
func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
|
||||
|
||||
peer := a.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id)
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
seenRoute := make(map[route.ID]struct{})
|
||||
|
||||
takeRoute := func(r *route.Route, id string) {
|
||||
if _, ok := seenRoute[r.ID]; ok {
|
||||
return
|
||||
}
|
||||
seenRoute[r.ID] = struct{}{}
|
||||
|
||||
if r.Enabled {
|
||||
r.Peer = peer.Key
|
||||
enabledRoutes = append(enabledRoutes, r)
|
||||
return
|
||||
}
|
||||
disabledRoutes = append(disabledRoutes, r)
|
||||
}
|
||||
|
||||
for _, r := range a.Routes {
|
||||
for _, groupID := range r.PeerGroups {
|
||||
group := a.GetGroup(groupID)
|
||||
if group == nil {
|
||||
log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id)
|
||||
continue
|
||||
}
|
||||
for _, id := range group.Peers {
|
||||
if id != peerID {
|
||||
continue
|
||||
}
|
||||
|
||||
newPeerRoute := r.Copy()
|
||||
newPeerRoute.Peer = id
|
||||
newPeerRoute.PeerGroups = nil
|
||||
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map
|
||||
takeRoute(newPeerRoute, id)
|
||||
break
|
||||
}
|
||||
}
|
||||
if r.Peer == peerID {
|
||||
takeRoute(r.Copy(), peerID)
|
||||
}
|
||||
}
|
||||
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||
func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route {
|
||||
var routes []*route.Route
|
||||
@@ -276,106 +165,6 @@ func (a *Account) GetGroup(groupID string) *Group {
|
||||
return a.Groups[groupID]
|
||||
}
|
||||
|
||||
// GetPeerNetworkMap returns the networkmap for the given peer ID.
|
||||
func (a *Account) GetPeerNetworkMap(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
validatedPeersMap map[string]struct{},
|
||||
resourcePolicies map[string][]*Policy,
|
||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
groupIDToUserIDs map[string][]string,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
peer := a.Peers[peerID]
|
||||
if peer == nil {
|
||||
return &NetworkMap{
|
||||
Network: a.Network.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := validatedPeersMap[peerID]; !ok {
|
||||
return &NetworkMap{
|
||||
Network: a.Network.Copy(),
|
||||
}
|
||||
}
|
||||
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
|
||||
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
|
||||
// exclude expired peers
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
var expiredPeers []*nbpeer.Peer
|
||||
for _, p := range aclPeers {
|
||||
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
|
||||
if a.Settings.PeerLoginExpirationEnabled && expired {
|
||||
expiredPeers = append(expiredPeers, p)
|
||||
continue
|
||||
}
|
||||
peersToConnect = append(peersToConnect, p)
|
||||
}
|
||||
|
||||
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect, peerGroups)
|
||||
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
|
||||
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
|
||||
var networkResourcesFirewallRules []*RouteFirewallRule
|
||||
if isRouter {
|
||||
networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies)
|
||||
}
|
||||
peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers)
|
||||
|
||||
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: dnsManagementStatus,
|
||||
}
|
||||
|
||||
if dnsManagementStatus {
|
||||
var zones []nbdns.CustomZone
|
||||
|
||||
if peersCustomZone.Domain != "" {
|
||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
|
||||
zones = append(zones, nbdns.CustomZone{
|
||||
Domain: peersCustomZone.Domain,
|
||||
Records: records,
|
||||
})
|
||||
}
|
||||
|
||||
filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
|
||||
zones = append(zones, filteredAccountZones...)
|
||||
|
||||
dnsUpdate.CustomZones = zones
|
||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||
}
|
||||
|
||||
nm := &NetworkMap{
|
||||
Peers: peersToConnectIncludingRouters,
|
||||
Network: a.Network.Copy(),
|
||||
Routes: slices.Concat(networkResourcesRoutes, routesUpdate),
|
||||
DNSConfig: dnsUpdate,
|
||||
OfflinePeers: expiredPeers,
|
||||
FirewallRules: firewallRules,
|
||||
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
EnableSSH: enableSSH,
|
||||
}
|
||||
|
||||
if metrics != nil {
|
||||
objectCount := int64(len(peersToConnectIncludingRouters) + len(expiredPeers) + len(routesUpdate) + len(networkResourcesRoutes) + len(firewallRules) + +len(networkResourcesFirewallRules) + len(routesFirewallRules))
|
||||
metrics.CountNetworkMapObjects(objectCount)
|
||||
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
if objectCount > 5000 {
|
||||
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+
|
||||
"peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d, network resources routes: %d, network resources firewall rules: %d, routes firewall rules: %d",
|
||||
a.Id, objectCount, len(peersToConnectIncludingRouters), len(expiredPeers), len(routesUpdate), len(firewallRules), len(networkResourcesRoutes), len(networkResourcesFirewallRules), len(routesFirewallRules))
|
||||
}
|
||||
}
|
||||
|
||||
return nm
|
||||
}
|
||||
|
||||
func (a *Account) addNetworksRoutingPeers(
|
||||
networkResourcesRoutes []*route.Route,
|
||||
peer *nbpeer.Peer,
|
||||
@@ -421,39 +210,6 @@ func (a *Account) addNetworksRoutingPeers(
|
||||
return peersToConnect
|
||||
}
|
||||
|
||||
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
||||
groupList := account.GetPeerGroups(peerID)
|
||||
|
||||
var peerNSGroups []*nbdns.NameServerGroup
|
||||
|
||||
for _, nsGroup := range account.NameServerGroups {
|
||||
if !nsGroup.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, gID := range nsGroup.Groups {
|
||||
_, found := groupList[gID]
|
||||
if found {
|
||||
if !peerIsNameserver(account.GetPeer(peerID), nsGroup) {
|
||||
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return peerNSGroups
|
||||
}
|
||||
|
||||
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup
|
||||
func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if peer.IP.Equal(ns.IP.AsSlice()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) {
|
||||
for _, peer := range account.Peers {
|
||||
label, err := GetPeerHostLabel(peer.Name, peerLabels)
|
||||
@@ -800,19 +556,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
|
||||
return grps
|
||||
}
|
||||
|
||||
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
|
||||
peerGroups := a.GetPeerGroups(peerID)
|
||||
enabled := true
|
||||
for _, groupID := range a.DNSSettings.DisabledManagementGroups {
|
||||
_, found := peerGroups[groupID]
|
||||
if found {
|
||||
enabled = false
|
||||
break
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
func (a *Account) GetPeerGroups(peerID string) LookupMap {
|
||||
groupList := make(LookupMap)
|
||||
for groupID, group := range a.Groups {
|
||||
@@ -941,8 +684,6 @@ func (a *Account) Copy() *Account {
|
||||
NetworkResources: networkResources,
|
||||
Services: services,
|
||||
Onboarding: a.Onboarding,
|
||||
NetworkMapCache: a.NetworkMapCache,
|
||||
nmapInitOnce: a.nmapInitOnce,
|
||||
Domains: domains,
|
||||
}
|
||||
}
|
||||
@@ -1304,31 +1045,6 @@ func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
|
||||
func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
|
||||
routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))
|
||||
|
||||
enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
|
||||
for _, route := range enabledRoutes {
|
||||
// If no access control groups are specified, accept all traffic.
|
||||
if len(route.AccessControlGroups) == 0 {
|
||||
defaultPermit := getDefaultPermit(route)
|
||||
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
|
||||
continue
|
||||
}
|
||||
|
||||
distributionPeers := a.getDistributionGroupsPeers(route)
|
||||
|
||||
for _, accessGroup := range route.AccessControlGroups {
|
||||
policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup})
|
||||
rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers)
|
||||
routesFirewallRules = append(routesFirewallRules, rules...)
|
||||
}
|
||||
}
|
||||
|
||||
return routesFirewallRules
|
||||
}
|
||||
|
||||
func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule {
|
||||
var fwRules []*RouteFirewallRule
|
||||
for _, policy := range policies {
|
||||
@@ -1387,50 +1103,6 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID
|
||||
return distributionGroupPeers
|
||||
}
|
||||
|
||||
func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} {
|
||||
distPeers := make(map[string]struct{})
|
||||
for _, id := range route.Groups {
|
||||
group := a.Groups[id]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pID := range group.Peers {
|
||||
distPeers[pID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return distPeers
|
||||
}
|
||||
|
||||
func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
|
||||
var rules []*RouteFirewallRule
|
||||
|
||||
sources := []string{"0.0.0.0/0"}
|
||||
if route.Network.Addr().Is6() {
|
||||
sources = []string{"::/0"}
|
||||
}
|
||||
rule := RouteFirewallRule{
|
||||
SourceRanges: sources,
|
||||
Action: string(PolicyTrafficActionAccept),
|
||||
Destination: route.Network.String(),
|
||||
Protocol: string(PolicyRuleProtocolALL),
|
||||
Domains: route.Domains,
|
||||
IsDynamic: route.IsDynamic(),
|
||||
RouteID: route.ID,
|
||||
}
|
||||
|
||||
rules = append(rules, &rule)
|
||||
|
||||
// dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally
|
||||
if route.IsDynamic() {
|
||||
ruleV6 := rule
|
||||
ruleV6.SourceRanges = []string{"::/0"}
|
||||
rules = append(rules, &ruleV6)
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
// GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
|
||||
// and returns a list of policies that have rules with destinations matching the specified groups.
|
||||
func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
|
||||
@@ -1508,65 +1180,6 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
|
||||
return resourcePolicies
|
||||
}
|
||||
|
||||
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
|
||||
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
|
||||
var isRoutingPeer bool
|
||||
var routes []*route.Route
|
||||
allSourcePeers := make(map[string]struct{}, len(a.Peers))
|
||||
|
||||
for _, resource := range a.NetworkResources {
|
||||
if !resource.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
var addSourcePeers bool
|
||||
|
||||
networkRoutingPeers, exists := routers[resource.NetworkID]
|
||||
if exists {
|
||||
if router, ok := networkRoutingPeers[peerID]; ok {
|
||||
isRoutingPeer, addSourcePeers = true, true
|
||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, router, resourcePolicies)...)
|
||||
}
|
||||
}
|
||||
|
||||
addedResourceRoute := false
|
||||
for _, policy := range resourcePolicies[resource.ID] {
|
||||
var peers []string
|
||||
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
|
||||
peers = []string{policy.Rules[0].SourceResource.ID}
|
||||
} else {
|
||||
peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
|
||||
}
|
||||
if addSourcePeers {
|
||||
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
|
||||
allSourcePeers[pID] = struct{}{}
|
||||
}
|
||||
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
|
||||
// add routes for the resource if the peer is in the distribution group
|
||||
for peerId, router := range networkRoutingPeers {
|
||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
|
||||
}
|
||||
addedResourceRoute = true
|
||||
}
|
||||
if addedResourceRoute {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return isRoutingPeer, routes, allSourcePeers
|
||||
}
|
||||
|
||||
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
|
||||
var dest []string
|
||||
for _, peerID := range inputPeers {
|
||||
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
|
||||
dest = append(dest, peerID)
|
||||
}
|
||||
}
|
||||
return dest
|
||||
}
|
||||
|
||||
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
|
||||
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
|
||||
for _, groupID := range groups {
|
||||
@@ -1658,22 +1271,6 @@ func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// getNetworkResourcesRoutes convert the network resources list to routes list.
|
||||
func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route {
|
||||
resourceAppliedPolicies := resourcePolicies[resource.ID]
|
||||
|
||||
var routes []*route.Route
|
||||
// distribute the resource routes only if there is policy applied to it
|
||||
if len(resourceAppliedPolicies) > 0 {
|
||||
peer := a.GetPeer(peerId)
|
||||
if peer != nil {
|
||||
routes = append(routes, resource.ToRoute(peer, router))
|
||||
}
|
||||
}
|
||||
|
||||
return routes
|
||||
}
|
||||
|
||||
func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter {
|
||||
routers := make(map[string]map[string]*routerTypes.NetworkRouter)
|
||||
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -19,7 +17,6 @@ import (
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -451,402 +448,6 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) {
|
||||
require.Len(t, result, 0)
|
||||
}
|
||||
|
||||
const (
|
||||
accID = "accountID"
|
||||
network1ID = "network1ID"
|
||||
group1ID = "group1"
|
||||
accNetResourcePeer1ID = "peer1"
|
||||
accNetResourcePeer2ID = "peer2"
|
||||
accNetResourceRouter1ID = "router1"
|
||||
accNetResource1ID = "resource1ID"
|
||||
accNetResourceRestrictPostureCheckID = "restrictPostureCheck"
|
||||
accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck"
|
||||
accNetResourceLockedPostureCheckID = "lockedPostureCheck"
|
||||
accNetResourceLinuxPostureCheckID = "linuxPostureCheck"
|
||||
)
|
||||
|
||||
var (
|
||||
accNetResourcePeer1IP = net.IP{192, 168, 1, 1}
|
||||
accNetResourcePeer2IP = net.IP{192, 168, 1, 2}
|
||||
accNetResourceRouter1IP = net.IP{192, 168, 1, 3}
|
||||
accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}}
|
||||
)
|
||||
|
||||
func getBasicAccountsWithResource() *Account {
|
||||
return &Account{
|
||||
Id: accID,
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
accNetResourcePeer1ID: {
|
||||
ID: accNetResourcePeer1ID,
|
||||
AccountID: accID,
|
||||
Key: "peer1Key",
|
||||
IP: accNetResourcePeer1IP,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
GoOS: "linux",
|
||||
WtVersion: "0.35.1",
|
||||
KernelVersion: "4.4.0",
|
||||
},
|
||||
},
|
||||
accNetResourcePeer2ID: {
|
||||
ID: accNetResourcePeer2ID,
|
||||
AccountID: accID,
|
||||
Key: "peer2Key",
|
||||
IP: accNetResourcePeer2IP,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
GoOS: "windows",
|
||||
WtVersion: "0.34.1",
|
||||
KernelVersion: "4.4.0",
|
||||
},
|
||||
},
|
||||
accNetResourceRouter1ID: {
|
||||
ID: accNetResourceRouter1ID,
|
||||
AccountID: accID,
|
||||
Key: "router1Key",
|
||||
IP: accNetResourceRouter1IP,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
GoOS: "linux",
|
||||
WtVersion: "0.35.1",
|
||||
KernelVersion: "4.4.0",
|
||||
},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
group1ID: {
|
||||
ID: group1ID,
|
||||
Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID},
|
||||
},
|
||||
},
|
||||
Networks: []*networkTypes.Network{
|
||||
{
|
||||
ID: network1ID,
|
||||
AccountID: accID,
|
||||
Name: "network1",
|
||||
},
|
||||
},
|
||||
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||
{
|
||||
ID: accNetResourceRouter1ID,
|
||||
NetworkID: network1ID,
|
||||
AccountID: accID,
|
||||
Peer: accNetResourceRouter1ID,
|
||||
PeerGroups: []string{},
|
||||
Masquerade: false,
|
||||
Metric: 100,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
NetworkResources: []*resourceTypes.NetworkResource{
|
||||
{
|
||||
ID: accNetResource1ID,
|
||||
AccountID: accID,
|
||||
NetworkID: network1ID,
|
||||
Address: "10.10.10.0/24",
|
||||
Prefix: netip.MustParsePrefix("10.10.10.0/24"),
|
||||
Type: resourceTypes.NetworkResourceType("subnet"),
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
Policies: []*Policy{
|
||||
{
|
||||
ID: "policy1ID",
|
||||
AccountID: accID,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "rule1ID",
|
||||
Enabled: true,
|
||||
Sources: []string{group1ID},
|
||||
DestinationResource: Resource{
|
||||
ID: accNetResource1ID,
|
||||
Type: "Host",
|
||||
},
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Ports: []string{"80"},
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: nil,
|
||||
},
|
||||
},
|
||||
PostureChecks: []*posture.Checks{
|
||||
{
|
||||
ID: accNetResourceRestrictPostureCheckID,
|
||||
Name: accNetResourceRestrictPostureCheckID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.35.0",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: accNetResourceRelaxedPostureCheckID,
|
||||
Name: accNetResourceRelaxedPostureCheckID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.0.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: accNetResourceLockedPostureCheckID,
|
||||
Name: accNetResourceLockedPostureCheckID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "7.7.7",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: accNetResourceLinuxPostureCheckID,
|
||||
Name: accNetResourceLinuxPostureCheckID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
OSVersionCheck: &posture.OSVersionCheck{
|
||||
Linux: &posture.MinKernelVersionCheck{
|
||||
MinKernelVersion: "0.0.0"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
// all peers should match the policy
|
||||
|
||||
// validate for peer1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate for peer2
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
|
||||
|
||||
// validate rules for router1
|
||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||
assert.Len(t, rules, 1, "expected rules count don't match")
|
||||
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
||||
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
||||
}
|
||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
// should allow peer1 to match the policy
|
||||
policy := account.Policies[0]
|
||||
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
|
||||
|
||||
// validate for peer1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate for peer2
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
||||
|
||||
// validate rules for router1
|
||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||
assert.Len(t, rules, 1, "expected rules count don't match")
|
||||
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
||||
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
||||
}
|
||||
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
// should not match any peer
|
||||
policy := account.Policies[0]
|
||||
policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID}
|
||||
|
||||
// validate for peer1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate for peer2
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate rules for router1
|
||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||
assert.Len(t, rules, 0, "expected rules count don't match")
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
// should allow peer1 to match the policy
|
||||
policy := account.Policies[0]
|
||||
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
|
||||
|
||||
// should allow peer1 and peer2 to match the policy
|
||||
newPolicy := &Policy{
|
||||
ID: "policy2ID",
|
||||
AccountID: accID,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "policy2ID",
|
||||
Enabled: true,
|
||||
Sources: []string{group1ID},
|
||||
DestinationResource: Resource{
|
||||
ID: accNetResource1ID,
|
||||
Type: "Host",
|
||||
},
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Ports: []string{"22"},
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID},
|
||||
}
|
||||
|
||||
account.Policies = append(account.Policies, newPolicy)
|
||||
|
||||
// validate for peer1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate for peer2
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
|
||||
|
||||
// validate rules for router1
|
||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||
assert.Len(t, rules, 2, "expected rules count don't match")
|
||||
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
||||
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
||||
}
|
||||
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
||||
}
|
||||
|
||||
assert.Equal(t, uint16(22), rules[1].Port, "should have port 22")
|
||||
assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp")
|
||||
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String())
|
||||
}
|
||||
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
// two posture checks should match only the peers that match both checks
|
||||
policy := account.Policies[0]
|
||||
policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID}
|
||||
|
||||
// validate for peer1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate for peer2
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
||||
|
||||
// validate rules for router1
|
||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||
assert.Len(t, rules, 1, "expected rules count don't match")
|
||||
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
||||
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
||||
}
|
||||
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
account.Peers["router2Id"] = &nbpeer.Peer{Key: "router2Key", ID: "router2Id", AccountID: accID, IP: net.IP{192, 168, 1, 4}}
|
||||
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
|
||||
ID: "router2Id",
|
||||
NetworkID: network1ID,
|
||||
AccountID: accID,
|
||||
Peer: "router2Id",
|
||||
})
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
|
||||
}
|
||||
|
||||
func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Holder struct {
|
||||
mu sync.RWMutex
|
||||
accounts map[string]*Account
|
||||
}
|
||||
|
||||
func NewHolder() *Holder {
|
||||
return &Holder{
|
||||
accounts: make(map[string]*Account),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Holder) GetAccount(id string) *Account {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return h.accounts[id]
|
||||
}
|
||||
|
||||
func (h *Holder) AddAccount(account *Account) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
a := h.accounts[account.Id]
|
||||
if a != nil && a.Network.CurrentSerial() >= account.Network.CurrentSerial() {
|
||||
return
|
||||
}
|
||||
h.accounts[account.Id] = account
|
||||
}
|
||||
|
||||
func (h *Holder) LoadOrStoreFunc(ctx context.Context, id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if acc, ok := h.accounts[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
account, err := accGetter(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.accounts[id] = account
|
||||
return account, nil
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) {
|
||||
if a.NetworkMapCache != nil {
|
||||
return
|
||||
}
|
||||
a.nmapInitOnce.Do(func() {
|
||||
a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
}
|
||||
|
||||
func (a *Account) GetPeerNetworkMapExp(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
accountZones []*zones.Zone,
|
||||
validatedPeers map[string]struct{},
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, accountZones, validatedPeers, metrics)
|
||||
}
|
||||
|
||||
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
|
||||
if a.NetworkMapCache == nil {
|
||||
return nil
|
||||
}
|
||||
return a.NetworkMapCache.OnPeerAddedIncremental(a, peerId)
|
||||
}
|
||||
|
||||
func (a *Account) OnPeersAddedUpdNetworkMapCache(peerIds ...string) {
|
||||
if a.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
a.NetworkMapCache.EnqueuePeersForIncrementalAdd(a, peerIds...)
|
||||
}
|
||||
|
||||
func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error {
|
||||
if a.NetworkMapCache == nil {
|
||||
return nil
|
||||
}
|
||||
return a.NetworkMapCache.OnPeerDeleted(a, peerId)
|
||||
}
|
||||
|
||||
func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) {
|
||||
if a.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
a.NetworkMapCache.UpdatePeer(peer)
|
||||
}
|
||||
|
||||
func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) {
|
||||
a.initNetworkMapBuilder(validatedPeers)
|
||||
}
|
||||
@@ -1,592 +0,0 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) {
|
||||
account := createTestAccount()
|
||||
ctx := context.Background()
|
||||
|
||||
peerID := testingPeerID
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
pid := fmt.Sprintf("peer-%d", i)
|
||||
if pid == offlinePeerID {
|
||||
continue
|
||||
}
|
||||
validatedPeersMap[pid] = struct{}{}
|
||||
}
|
||||
|
||||
peersCustomZone := nbdns.CustomZone{}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(
|
||||
ctx,
|
||||
peerID,
|
||||
peersCustomZone,
|
||||
nil,
|
||||
validatedPeersMap,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
nil,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
components := account.GetPeerNetworkMapComponents(
|
||||
ctx,
|
||||
peerID,
|
||||
peersCustomZone,
|
||||
nil,
|
||||
validatedPeersMap,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
if components == nil {
|
||||
t.Fatal("GetPeerNetworkMapComponents returned nil")
|
||||
}
|
||||
|
||||
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
|
||||
|
||||
if newNetworkMap == nil {
|
||||
t.Fatal("CalculateNetworkMapFromComponents returned nil")
|
||||
}
|
||||
|
||||
compareNetworkMaps(t, legacyNetworkMap, newNetworkMap)
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) {
|
||||
account := createTestAccount()
|
||||
ctx := context.Background()
|
||||
|
||||
peerID := testingPeerID
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
pid := fmt.Sprintf("peer-%d", i)
|
||||
if pid == offlinePeerID {
|
||||
continue
|
||||
}
|
||||
validatedPeersMap[pid] = struct{}{}
|
||||
}
|
||||
|
||||
peersCustomZone := nbdns.CustomZone{}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(
|
||||
ctx,
|
||||
peerID,
|
||||
peersCustomZone,
|
||||
nil,
|
||||
validatedPeersMap,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
nil,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
components := account.GetPeerNetworkMapComponents(
|
||||
ctx,
|
||||
peerID,
|
||||
peersCustomZone,
|
||||
nil,
|
||||
validatedPeersMap,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil")
|
||||
|
||||
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
|
||||
require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil")
|
||||
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
|
||||
componentsJSON, err := json.MarshalIndent(components, "", " ")
|
||||
require.NoError(t, err, "error marshaling components to JSON")
|
||||
|
||||
legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
|
||||
newJSON, err := json.MarshalIndent(newNetworkMap, "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
|
||||
goldenDir := filepath.Join("testdata", "comparison")
|
||||
err = os.MkdirAll(goldenDir, 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json")
|
||||
err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644)
|
||||
require.NoError(t, err, "error writing legacy golden file")
|
||||
|
||||
newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json")
|
||||
err = os.WriteFile(newGoldenPath, newJSON, 0644)
|
||||
require.NoError(t, err, "error writing components golden file")
|
||||
|
||||
componentsPath := filepath.Join(goldenDir, "components.json")
|
||||
err = os.WriteFile(componentsPath, componentsJSON, 0644)
|
||||
require.NoError(t, err, "error writing components golden file")
|
||||
|
||||
require.JSONEq(t, string(legacyJSON), string(newJSON),
|
||||
"NetworkMaps from legacy and components approaches do not match.\n"+
|
||||
"Legacy JSON saved to: %s\n"+
|
||||
"Components JSON saved to: %s",
|
||||
legacyGoldenPath, newGoldenPath)
|
||||
|
||||
t.Logf("✅ NetworkMaps are identical")
|
||||
t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath)
|
||||
t.Logf(" Components NetworkMap: %s", newGoldenPath)
|
||||
}
|
||||
|
||||
func normalizeAndSortNetworkMap(nm *NetworkMap) {
|
||||
if nm == nil {
|
||||
return
|
||||
}
|
||||
|
||||
sort.Slice(nm.Peers, func(i, j int) bool {
|
||||
return nm.Peers[i].ID < nm.Peers[j].ID
|
||||
})
|
||||
|
||||
sort.Slice(nm.OfflinePeers, func(i, j int) bool {
|
||||
return nm.OfflinePeers[i].ID < nm.OfflinePeers[j].ID
|
||||
})
|
||||
|
||||
sort.Slice(nm.Routes, func(i, j int) bool {
|
||||
return string(nm.Routes[i].ID) < string(nm.Routes[j].ID)
|
||||
})
|
||||
|
||||
sort.Slice(nm.FirewallRules, func(i, j int) bool {
|
||||
if nm.FirewallRules[i].PeerIP != nm.FirewallRules[j].PeerIP {
|
||||
return nm.FirewallRules[i].PeerIP < nm.FirewallRules[j].PeerIP
|
||||
}
|
||||
if nm.FirewallRules[i].Direction != nm.FirewallRules[j].Direction {
|
||||
return nm.FirewallRules[i].Direction < nm.FirewallRules[j].Direction
|
||||
}
|
||||
if nm.FirewallRules[i].Protocol != nm.FirewallRules[j].Protocol {
|
||||
return nm.FirewallRules[i].Protocol < nm.FirewallRules[j].Protocol
|
||||
}
|
||||
if nm.FirewallRules[i].Port != nm.FirewallRules[j].Port {
|
||||
return nm.FirewallRules[i].Port < nm.FirewallRules[j].Port
|
||||
}
|
||||
return nm.FirewallRules[i].PolicyID < nm.FirewallRules[j].PolicyID
|
||||
})
|
||||
|
||||
for i := range nm.RoutesFirewallRules {
|
||||
sort.Strings(nm.RoutesFirewallRules[i].SourceRanges)
|
||||
}
|
||||
|
||||
sort.Slice(nm.RoutesFirewallRules, func(i, j int) bool {
|
||||
if nm.RoutesFirewallRules[i].Destination != nm.RoutesFirewallRules[j].Destination {
|
||||
return nm.RoutesFirewallRules[i].Destination < nm.RoutesFirewallRules[j].Destination
|
||||
}
|
||||
|
||||
minLen := len(nm.RoutesFirewallRules[i].SourceRanges)
|
||||
if len(nm.RoutesFirewallRules[j].SourceRanges) < minLen {
|
||||
minLen = len(nm.RoutesFirewallRules[j].SourceRanges)
|
||||
}
|
||||
for k := 0; k < minLen; k++ {
|
||||
if nm.RoutesFirewallRules[i].SourceRanges[k] != nm.RoutesFirewallRules[j].SourceRanges[k] {
|
||||
return nm.RoutesFirewallRules[i].SourceRanges[k] < nm.RoutesFirewallRules[j].SourceRanges[k]
|
||||
}
|
||||
}
|
||||
if len(nm.RoutesFirewallRules[i].SourceRanges) != len(nm.RoutesFirewallRules[j].SourceRanges) {
|
||||
return len(nm.RoutesFirewallRules[i].SourceRanges) < len(nm.RoutesFirewallRules[j].SourceRanges)
|
||||
}
|
||||
|
||||
if string(nm.RoutesFirewallRules[i].RouteID) != string(nm.RoutesFirewallRules[j].RouteID) {
|
||||
return string(nm.RoutesFirewallRules[i].RouteID) < string(nm.RoutesFirewallRules[j].RouteID)
|
||||
}
|
||||
|
||||
if nm.RoutesFirewallRules[i].PolicyID != nm.RoutesFirewallRules[j].PolicyID {
|
||||
return nm.RoutesFirewallRules[i].PolicyID < nm.RoutesFirewallRules[j].PolicyID
|
||||
}
|
||||
|
||||
if nm.RoutesFirewallRules[i].Port != nm.RoutesFirewallRules[j].Port {
|
||||
return nm.RoutesFirewallRules[i].Port < nm.RoutesFirewallRules[j].Port
|
||||
}
|
||||
|
||||
return nm.RoutesFirewallRules[i].Protocol < nm.RoutesFirewallRules[j].Protocol
|
||||
})
|
||||
|
||||
if nm.DNSConfig.CustomZones != nil {
|
||||
for i := range nm.DNSConfig.CustomZones {
|
||||
sort.Slice(nm.DNSConfig.CustomZones[i].Records, func(a, b int) bool {
|
||||
return nm.DNSConfig.CustomZones[i].Records[a].Name < nm.DNSConfig.CustomZones[i].Records[b].Name
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(nm.DNSConfig.NameServerGroups) != 0 {
|
||||
sort.Slice(nm.DNSConfig.NameServerGroups, func(a, b int) bool {
|
||||
return nm.DNSConfig.NameServerGroups[a].Name < nm.DNSConfig.NameServerGroups[b].Name
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func compareNetworkMaps(t *testing.T, legacy, current *NetworkMap) {
|
||||
t.Helper()
|
||||
|
||||
if legacy.Network.Serial != current.Network.Serial {
|
||||
t.Errorf("Network Serial mismatch: legacy=%d, current=%d", legacy.Network.Serial, current.Network.Serial)
|
||||
}
|
||||
|
||||
if len(legacy.Peers) != len(current.Peers) {
|
||||
t.Errorf("Peers count mismatch: legacy=%d, current=%d", len(legacy.Peers), len(current.Peers))
|
||||
}
|
||||
|
||||
legacyPeerIDs := make(map[string]bool)
|
||||
for _, p := range legacy.Peers {
|
||||
legacyPeerIDs[p.ID] = true
|
||||
}
|
||||
|
||||
for _, p := range current.Peers {
|
||||
if !legacyPeerIDs[p.ID] {
|
||||
t.Errorf("Current NetworkMap contains peer %s not in legacy", p.ID)
|
||||
}
|
||||
}
|
||||
|
||||
if len(legacy.OfflinePeers) != len(current.OfflinePeers) {
|
||||
t.Errorf("OfflinePeers count mismatch: legacy=%d, current=%d", len(legacy.OfflinePeers), len(current.OfflinePeers))
|
||||
}
|
||||
|
||||
if len(legacy.FirewallRules) != len(current.FirewallRules) {
|
||||
t.Logf("FirewallRules count mismatch: legacy=%d, current=%d", len(legacy.FirewallRules), len(current.FirewallRules))
|
||||
}
|
||||
|
||||
if len(legacy.Routes) != len(current.Routes) {
|
||||
t.Logf("Routes count mismatch: legacy=%d, current=%d", len(legacy.Routes), len(current.Routes))
|
||||
}
|
||||
|
||||
if len(legacy.RoutesFirewallRules) != len(current.RoutesFirewallRules) {
|
||||
t.Logf("RoutesFirewallRules count mismatch: legacy=%d, current=%d", len(legacy.RoutesFirewallRules), len(current.RoutesFirewallRules))
|
||||
}
|
||||
|
||||
if legacy.DNSConfig.ServiceEnable != current.DNSConfig.ServiceEnable {
|
||||
t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, current=%v", legacy.DNSConfig.ServiceEnable, current.DNSConfig.ServiceEnable)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
numPeers = 100
|
||||
devGroupID = "group-dev"
|
||||
opsGroupID = "group-ops"
|
||||
allGroupID = "group-all"
|
||||
routeID = route.ID("route-main")
|
||||
routeHA1ID = route.ID("route-ha-1")
|
||||
routeHA2ID = route.ID("route-ha-2")
|
||||
policyIDDevOps = "policy-dev-ops"
|
||||
policyIDAll = "policy-all"
|
||||
policyIDPosture = "policy-posture"
|
||||
policyIDDrop = "policy-drop"
|
||||
postureCheckID = "posture-check-ver"
|
||||
networkResourceID = "res-database"
|
||||
networkID = "net-database"
|
||||
networkRouterID = "router-database"
|
||||
nameserverGroupID = "ns-group-main"
|
||||
testingPeerID = "peer-60"
|
||||
expiredPeerID = "peer-98"
|
||||
offlinePeerID = "peer-99"
|
||||
routingPeerID = "peer-95"
|
||||
testAccountID = "account-comparison-test"
|
||||
)
|
||||
|
||||
func createTestAccount() *Account {
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{}
|
||||
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
ip := net.IP{100, 64, 0, byte(i + 1)}
|
||||
wtVersion := "0.25.0"
|
||||
if i%2 == 0 {
|
||||
wtVersion = "0.40.0"
|
||||
}
|
||||
|
||||
p := &nbpeer.Peer{
|
||||
ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1),
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"},
|
||||
}
|
||||
|
||||
if peerID == expiredPeerID {
|
||||
p.LoginExpirationEnabled = true
|
||||
pastTimestamp := time.Now().Add(-2 * time.Hour)
|
||||
p.LastLogin = &pastTimestamp
|
||||
}
|
||||
|
||||
peers[peerID] = p
|
||||
allGroupPeers = append(allGroupPeers, peerID)
|
||||
if i < numPeers/2 {
|
||||
devGroupPeers = append(devGroupPeers, peerID)
|
||||
} else {
|
||||
opsGroupPeers = append(opsGroupPeers, peerID)
|
||||
}
|
||||
}
|
||||
|
||||
groups := map[string]*Group{
|
||||
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
|
||||
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
|
||||
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
|
||||
}
|
||||
|
||||
policies := []*Policy{
|
||||
{
|
||||
ID: policyIDAll, Name: "Default-Allow", Enabled: true,
|
||||
Rules: []*PolicyRule{{
|
||||
ID: policyIDAll, Name: "Allow All", Enabled: true, Action: PolicyTrafficActionAccept,
|
||||
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
|
||||
Sources: []string{allGroupID}, Destinations: []string{allGroupID},
|
||||
}},
|
||||
},
|
||||
{
|
||||
ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true,
|
||||
Rules: []*PolicyRule{{
|
||||
ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: PolicyTrafficActionAccept,
|
||||
Protocol: PolicyRuleProtocolTCP, Bidirectional: false,
|
||||
PortRanges: []RulePortRange{{Start: 8080, End: 8090}},
|
||||
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
|
||||
}},
|
||||
},
|
||||
{
|
||||
ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true,
|
||||
Rules: []*PolicyRule{{
|
||||
ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: PolicyTrafficActionDrop,
|
||||
Protocol: PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true,
|
||||
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
|
||||
}},
|
||||
},
|
||||
{
|
||||
ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true,
|
||||
SourcePostureChecks: []string{postureCheckID},
|
||||
Rules: []*PolicyRule{{
|
||||
ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: PolicyTrafficActionAccept,
|
||||
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
|
||||
Sources: []string{opsGroupID}, DestinationResource: Resource{ID: networkResourceID},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
routes := map[route.ID]*route.Route{
|
||||
routeID: {
|
||||
ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"),
|
||||
Peer: peers["peer-75"].Key,
|
||||
PeerID: "peer-75",
|
||||
Description: "Route to internal resource", Enabled: true,
|
||||
PeerGroups: []string{devGroupID, opsGroupID},
|
||||
Groups: []string{devGroupID, opsGroupID},
|
||||
AccessControlGroups: []string{devGroupID},
|
||||
},
|
||||
routeHA1ID: {
|
||||
ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
|
||||
Peer: peers["peer-80"].Key,
|
||||
PeerID: "peer-80",
|
||||
Description: "HA Route 1", Enabled: true, Metric: 1000,
|
||||
PeerGroups: []string{allGroupID},
|
||||
Groups: []string{allGroupID},
|
||||
AccessControlGroups: []string{allGroupID},
|
||||
},
|
||||
routeHA2ID: {
|
||||
ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
|
||||
Peer: peers["peer-90"].Key,
|
||||
PeerID: "peer-90",
|
||||
Description: "HA Route 2", Enabled: true, Metric: 900,
|
||||
PeerGroups: []string{devGroupID, opsGroupID},
|
||||
Groups: []string{devGroupID, opsGroupID},
|
||||
AccessControlGroups: []string{allGroupID},
|
||||
},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes,
|
||||
Network: &Network{
|
||||
Identifier: "net-comparison-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
|
||||
},
|
||||
DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}},
|
||||
NameServerGroups: map[string]*nbdns.NameServerGroup{
|
||||
nameserverGroupID: {
|
||||
ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID},
|
||||
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}},
|
||||
},
|
||||
},
|
||||
PostureChecks: []*posture.Checks{
|
||||
{ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
|
||||
}},
|
||||
},
|
||||
NetworkResources: []*resourceTypes.NetworkResource{
|
||||
{ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"},
|
||||
},
|
||||
Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}},
|
||||
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||
{ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID},
|
||||
},
|
||||
Settings: &Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
|
||||
}
|
||||
|
||||
for _, p := range account.Policies {
|
||||
p.AccountID = account.Id
|
||||
}
|
||||
for _, r := range account.Routes {
|
||||
r.AccountID = account.Id
|
||||
}
|
||||
|
||||
return account
|
||||
}
|
||||
|
||||
func BenchmarkLegacyNetworkMap(b *testing.B) {
|
||||
account := createTestAccount()
|
||||
ctx := context.Background()
|
||||
peerID := testingPeerID
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
pid := fmt.Sprintf("peer-%d", i)
|
||||
if pid != offlinePeerID {
|
||||
validatedPeersMap[pid] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
peersCustomZone := nbdns.CustomZone{}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = account.GetPeerNetworkMap(
|
||||
ctx,
|
||||
peerID,
|
||||
peersCustomZone,
|
||||
nil,
|
||||
validatedPeersMap,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
nil,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkComponentsNetworkMap(b *testing.B) {
|
||||
account := createTestAccount()
|
||||
ctx := context.Background()
|
||||
peerID := testingPeerID
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
pid := fmt.Sprintf("peer-%d", i)
|
||||
if pid != offlinePeerID {
|
||||
validatedPeersMap[pid] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
peersCustomZone := nbdns.CustomZone{}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
components := account.GetPeerNetworkMapComponents(
|
||||
ctx,
|
||||
peerID,
|
||||
peersCustomZone,
|
||||
nil,
|
||||
validatedPeersMap,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
_ = CalculateNetworkMapFromComponents(ctx, components)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkComponentsCreation(b *testing.B) {
|
||||
account := createTestAccount()
|
||||
ctx := context.Background()
|
||||
peerID := testingPeerID
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
pid := fmt.Sprintf("peer-%d", i)
|
||||
if pid != offlinePeerID {
|
||||
validatedPeersMap[pid] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
peersCustomZone := nbdns.CustomZone{}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = account.GetPeerNetworkMapComponents(
|
||||
ctx,
|
||||
peerID,
|
||||
peersCustomZone,
|
||||
nil,
|
||||
validatedPeersMap,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCalculationFromComponents(b *testing.B) {
|
||||
account := createTestAccount()
|
||||
ctx := context.Background()
|
||||
peerID := testingPeerID
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
pid := fmt.Sprintf("peer-%d", i)
|
||||
if pid != offlinePeerID {
|
||||
validatedPeersMap[pid] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
peersCustomZone := nbdns.CustomZone{}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
components := account.GetPeerNetworkMapComponents(
|
||||
ctx,
|
||||
peerID,
|
||||
peersCustomZone,
|
||||
nil,
|
||||
validatedPeersMap,
|
||||
resourcePolicies,
|
||||
routers,
|
||||
groupIDToUserIDs,
|
||||
)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = CalculateNetworkMapFromComponents(ctx, components)
|
||||
}
|
||||
}
|
||||
@@ -19,8 +19,6 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const EnvNewNetworkMapCompacted = "NB_NETWORK_MAP_COMPACTED"
|
||||
|
||||
type NetworkMapComponents struct {
|
||||
PeerID string
|
||||
|
||||
|
||||
787
management/server/types/networkmap_components_test.go
Normal file
787
management/server/types/networkmap_components_test.go
Normal file
@@ -0,0 +1,787 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func networkMapFromComponents(t *testing.T, account *types.Account, peerID string, validatedPeers map[string]struct{}) *types.NetworkMap {
|
||||
t.Helper()
|
||||
return account.GetPeerNetworkMapFromComponents(
|
||||
context.Background(),
|
||||
peerID,
|
||||
account.GetPeersCustomZone(context.Background(), "netbird.io"),
|
||||
nil,
|
||||
validatedPeers,
|
||||
account.GetResourcePoliciesMap(),
|
||||
account.GetResourceRoutersMap(),
|
||||
nil,
|
||||
account.GetActiveGroupUsers(),
|
||||
)
|
||||
}
|
||||
|
||||
func allPeersValidated(account *types.Account, excludePeerIDs ...string) map[string]struct{} {
|
||||
excludeSet := make(map[string]struct{}, len(excludePeerIDs))
|
||||
for _, id := range excludePeerIDs {
|
||||
excludeSet[id] = struct{}{}
|
||||
}
|
||||
validated := make(map[string]struct{}, len(account.Peers))
|
||||
for id := range account.Peers {
|
||||
if _, excluded := excludeSet[id]; !excluded {
|
||||
validated[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
return validated
|
||||
}
|
||||
|
||||
func peerIDs(peers []*nbpeer.Peer) []string {
|
||||
ids := make([]string, len(peers))
|
||||
for i, p := range peers {
|
||||
ids[i] = p.ID
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_RegularPeerConnectivity(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
assert.NotNil(t, nm)
|
||||
assert.Contains(t, peerIDs(nm.Peers), "peer-dst-1", "should see peer from destination group via bidirectional policy")
|
||||
assert.Contains(t, peerIDs(nm.Peers), "peer-router-1", "should see router peer via resource policy")
|
||||
assert.NotContains(t, peerIDs(nm.Peers), "peer-src-1", "should not see itself")
|
||||
assert.Empty(t, nm.OfflinePeers, "no expired peers expected")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_IntraGroupConnectivity(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
account.Policies = append(account.Policies, &types.Policy{
|
||||
ID: "policy-intra-src", Name: "Intra-source connectivity", Enabled: true, AccountID: account.Id,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-intra-src", Name: "src <-> src", Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||
Bidirectional: true,
|
||||
Sources: []string{"group-src"}, Destinations: []string{"group-src"},
|
||||
}},
|
||||
})
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
assert.Contains(t, peerIDs(nm.Peers), "peer-src-2", "should see peer from same group with intra-group policy")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_FirewallRules(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
require.NotEmpty(t, nm.FirewallRules, "firewall rules should be generated")
|
||||
|
||||
var hasAcceptAll bool
|
||||
for _, rule := range nm.FirewallRules {
|
||||
if rule.Protocol == string(types.PolicyRuleProtocolALL) && rule.Action == string(types.PolicyTrafficActionAccept) {
|
||||
hasAcceptAll = true
|
||||
}
|
||||
}
|
||||
assert.True(t, hasAcceptAll, "should have an accept-all firewall rule from the base policy")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_LoginExpiration(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
account.Settings.PeerLoginExpirationEnabled = true
|
||||
account.Settings.PeerLoginExpiration = 1 * time.Hour
|
||||
|
||||
expiredTime := time.Now().Add(-2 * time.Hour)
|
||||
account.Peers["peer-dst-1"].LoginExpirationEnabled = true
|
||||
account.Peers["peer-dst-1"].LastLogin = &expiredTime
|
||||
|
||||
validated := allPeersValidated(account)
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
assert.Contains(t, peerIDs(nm.OfflinePeers), "peer-dst-1", "expired peer should be in OfflinePeers")
|
||||
assert.NotContains(t, peerIDs(nm.Peers), "peer-dst-1", "expired peer should NOT be in active Peers")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_InvalidatedPeerExcluded(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account, "peer-dst-1")
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
assert.NotContains(t, peerIDs(nm.Peers), "peer-dst-1", "non-validated peer should be excluded")
|
||||
assert.NotContains(t, peerIDs(nm.OfflinePeers), "peer-dst-1", "non-validated peer should not be in offline peers either")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_NonValidatedTargetPeer(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account, "peer-src-1")
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
assert.Empty(t, nm.Peers, "non-validated target peer should get empty network map")
|
||||
assert.Empty(t, nm.FirewallRules)
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_NetworkResourceRoutes_SourcePeer(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
var hasResourceRoute bool
|
||||
for _, r := range nm.Routes {
|
||||
if r.Network.String() == "10.200.0.1/32" {
|
||||
hasResourceRoute = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, hasResourceRoute, "source peer should receive route to network resource via router")
|
||||
assert.Contains(t, peerIDs(nm.Peers), "peer-router-1", "source peer should see the routing peer")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_NetworkResourceRoutes_RouterPeer(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
|
||||
|
||||
var hasResourceRoute bool
|
||||
for _, r := range nm.Routes {
|
||||
if r.Network.String() == "10.200.0.1/32" {
|
||||
hasResourceRoute = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, hasResourceRoute, "router peer should receive network resource route")
|
||||
assert.NotEmpty(t, nm.RoutesFirewallRules, "router peer should have route firewall rules for the resource")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_NetworkResourceRoutes_UnrelatedPeer(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-dst-1", validated)
|
||||
|
||||
for _, r := range nm.Routes {
|
||||
assert.NotEqual(t, "10.200.0.1/32", r.Network.String(), "unrelated peer should not receive network resource route")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_NetworkResource_WithPostureCheck(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
account.PostureChecks = []*posture.Checks{
|
||||
{ID: "pc-version", Name: "Version check", Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
|
||||
}},
|
||||
}
|
||||
account.Policies = append(account.Policies, &types.Policy{
|
||||
ID: "policy-posture-resource", Name: "Posture resource access", Enabled: true, AccountID: account.Id,
|
||||
SourcePostureChecks: []string{"pc-version"},
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-posture-resource", Name: "Posture -> Resource", Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||
Sources: []string{"group-src"},
|
||||
DestinationResource: types.Resource{ID: "resource-guarded"},
|
||||
}},
|
||||
})
|
||||
|
||||
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
|
||||
ID: "resource-guarded", NetworkID: "net-guarded", AccountID: account.Id, Enabled: true,
|
||||
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.1.1/32"), Address: "10.200.1.1/32",
|
||||
})
|
||||
account.Networks = append(account.Networks, &networkTypes.Network{
|
||||
ID: "net-guarded", Name: "Guarded Net", AccountID: account.Id,
|
||||
})
|
||||
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
|
||||
ID: "router-guarded", NetworkID: "net-guarded", Peer: "peer-router-1", Enabled: true, AccountID: account.Id,
|
||||
})
|
||||
|
||||
t.Run("peer passes posture check", func(t *testing.T) {
|
||||
account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0"
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
var hasGuardedRoute bool
|
||||
for _, r := range nm.Routes {
|
||||
if r.Network.String() == "10.200.1.1/32" {
|
||||
hasGuardedRoute = true
|
||||
}
|
||||
}
|
||||
assert.True(t, hasGuardedRoute, "peer passing posture check should get guarded resource route")
|
||||
})
|
||||
|
||||
t.Run("peer fails posture check", func(t *testing.T) {
|
||||
account.Peers["peer-src-1"].Meta.WtVersion = "0.20.0"
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
for _, r := range nm.Routes {
|
||||
assert.NotEqual(t, "10.200.1.1/32", r.Network.String(), "peer failing posture check should NOT get guarded resource route")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_NetworkResource_MultiplePostureChecks(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
account.PostureChecks = []*posture.Checks{
|
||||
{ID: "pc-version", Name: "Version", Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
|
||||
}},
|
||||
{ID: "pc-os", Name: "OS check", Checks: posture.ChecksDefinition{
|
||||
OSVersionCheck: &posture.OSVersionCheck{Linux: &posture.MinKernelVersionCheck{MinKernelVersion: "5.0"}},
|
||||
}},
|
||||
}
|
||||
|
||||
account.Policies = append(account.Policies, &types.Policy{
|
||||
ID: "policy-multi-posture", Name: "Multi posture", Enabled: true, AccountID: account.Id,
|
||||
SourcePostureChecks: []string{"pc-version", "pc-os"},
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-multi-posture", Name: "Multi posture rule", Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||
Sources: []string{"group-src"},
|
||||
DestinationResource: types.Resource{ID: "resource-strict"},
|
||||
}},
|
||||
})
|
||||
|
||||
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
|
||||
ID: "resource-strict", NetworkID: "net-strict", AccountID: account.Id, Enabled: true,
|
||||
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.2.1/32"), Address: "10.200.2.1/32",
|
||||
})
|
||||
account.Networks = append(account.Networks, &networkTypes.Network{
|
||||
ID: "net-strict", Name: "Strict Net", AccountID: account.Id,
|
||||
})
|
||||
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
|
||||
ID: "router-strict", NetworkID: "net-strict", Peer: "peer-router-1", Enabled: true, AccountID: account.Id,
|
||||
})
|
||||
|
||||
t.Run("passes both posture checks", func(t *testing.T) {
|
||||
account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0"
|
||||
account.Peers["peer-src-1"].Meta.GoOS = "linux"
|
||||
account.Peers["peer-src-1"].Meta.KernelVersion = "6.1.0"
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
var found bool
|
||||
for _, r := range nm.Routes {
|
||||
if r.Network.String() == "10.200.2.1/32" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "peer passing both checks should get resource route")
|
||||
})
|
||||
|
||||
t.Run("fails version posture check", func(t *testing.T) {
|
||||
account.Peers["peer-src-1"].Meta.WtVersion = "0.20.0"
|
||||
account.Peers["peer-src-1"].Meta.KernelVersion = "6.1.0"
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
for _, r := range nm.Routes {
|
||||
assert.NotEqual(t, "10.200.2.1/32", r.Network.String(), "peer failing version check should NOT get resource route")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fails OS posture check", func(t *testing.T) {
|
||||
account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0"
|
||||
account.Peers["peer-src-1"].Meta.KernelVersion = "4.0.0"
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
for _, r := range nm.Routes {
|
||||
assert.NotEqual(t, "10.200.2.1/32", r.Network.String(), "peer failing OS check should NOT get resource route")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_RouterPeerFirewallRules(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
|
||||
|
||||
var resourceFWRules []*types.RouteFirewallRule
|
||||
for _, rule := range nm.RoutesFirewallRules {
|
||||
if rule.Destination == "10.200.0.1/32" {
|
||||
resourceFWRules = append(resourceFWRules, rule)
|
||||
}
|
||||
}
|
||||
assert.NotEmpty(t, resourceFWRules, "router should have firewall rules for the network resource")
|
||||
|
||||
var hasSourcePeerIP bool
|
||||
for _, rule := range resourceFWRules {
|
||||
for _, sr := range rule.SourceRanges {
|
||||
if sr == account.Peers["peer-src-1"].IP.String()+"/32" || sr == account.Peers["peer-src-2"].IP.String()+"/32" {
|
||||
hasSourcePeerIP = true
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.True(t, hasSourcePeerIP, "resource firewall rules should include source peer IPs")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_DNSManagement(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
t.Run("peer in DNS-enabled group", func(t *testing.T) {
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
assert.True(t, nm.DNSConfig.ServiceEnable, "peer in non-disabled group should have DNS enabled")
|
||||
})
|
||||
|
||||
t.Run("peer in DNS-disabled group", func(t *testing.T) {
|
||||
nm := networkMapFromComponents(t, account, "peer-dst-1", validated)
|
||||
assert.False(t, nm.DNSConfig.ServiceEnable, "peer in DNS-disabled group should have DNS disabled")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_NameServerGroups(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
assert.True(t, nm.DNSConfig.ServiceEnable)
|
||||
|
||||
var hasNSGroup bool
|
||||
for _, ns := range nm.DNSConfig.NameServerGroups {
|
||||
if ns.ID == "ns-main" {
|
||||
hasNSGroup = true
|
||||
}
|
||||
}
|
||||
assert.True(t, hasNSGroup, "peer in NS group should receive nameserver configuration")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_RoutesWithHADeduplication(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
account.Routes["route-ha-1"] = &route.Route{
|
||||
ID: "route-ha-1", Network: netip.MustParsePrefix("172.16.0.0/16"),
|
||||
Peer: account.Peers["peer-dst-1"].Key, PeerID: "peer-dst-1",
|
||||
Enabled: true, Metric: 100, AccountID: account.Id,
|
||||
Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-dst"},
|
||||
}
|
||||
account.Routes["route-ha-2"] = &route.Route{
|
||||
ID: "route-ha-2", Network: netip.MustParsePrefix("172.16.0.0/16"),
|
||||
Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1",
|
||||
Enabled: true, Metric: 200, AccountID: account.Id,
|
||||
Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-src"},
|
||||
}
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
haCount := 0
|
||||
for _, r := range nm.Routes {
|
||||
if r.Network.String() == "172.16.0.0/16" {
|
||||
haCount++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, haCount, "peer should only receive one route from HA group (not both, since it's a member of one)")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_RoutesFirewallRulesForAccessControl(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
account.Routes["route-acl"] = &route.Route{
|
||||
ID: "route-acl", Network: netip.MustParsePrefix("192.168.100.0/24"),
|
||||
Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1",
|
||||
Enabled: true, Metric: 100, AccountID: account.Id,
|
||||
Groups: []string{"group-dst"},
|
||||
PeerGroups: []string{"group-src"},
|
||||
AccessControlGroups: []string{"group-dst"},
|
||||
}
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
var hasFWRule bool
|
||||
for _, rule := range nm.RoutesFirewallRules {
|
||||
if rule.Destination == "192.168.100.0/24" {
|
||||
hasFWRule = true
|
||||
}
|
||||
}
|
||||
assert.True(t, hasFWRule, "routing peer should have firewall rules for route with access control groups")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_RoutesDefaultPermit(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
account.Routes["route-open"] = &route.Route{
|
||||
ID: "route-open", Network: netip.MustParsePrefix("10.99.0.0/16"),
|
||||
Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1",
|
||||
Enabled: true, Metric: 100, AccountID: account.Id,
|
||||
Groups: []string{"group-src"},
|
||||
PeerGroups: []string{"group-src"},
|
||||
}
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
var hasFWRule bool
|
||||
for _, rule := range nm.RoutesFirewallRules {
|
||||
if rule.Destination == "10.99.0.0/16" {
|
||||
hasFWRule = true
|
||||
}
|
||||
}
|
||||
assert.True(t, hasFWRule, "route without access control groups should have default permit firewall rules")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_SSHAuthorizedUsers(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
account.Peers["peer-dst-1"].SSHEnabled = true
|
||||
|
||||
account.Policies = append(account.Policies, &types.Policy{
|
||||
ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: account.Id,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-ssh", Name: "SSH to dst", Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||
Bidirectional: true,
|
||||
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
|
||||
}},
|
||||
})
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-dst-1", validated)
|
||||
assert.True(t, nm.EnableSSH, "SSH-enabled peer with matching policy should have EnableSSH")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_DisabledPolicyIgnored(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
for _, p := range account.Policies {
|
||||
p.Enabled = false
|
||||
}
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
assert.Empty(t, nm.Peers, "with all policies disabled, peer should see no other peers")
|
||||
assert.Empty(t, nm.FirewallRules)
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_DisabledRouteIgnored(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
for _, r := range account.Routes {
|
||||
r.Enabled = false
|
||||
}
|
||||
for _, r := range account.NetworkResources {
|
||||
r.Enabled = false
|
||||
}
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
assert.Empty(t, nm.Routes, "disabled routes should not appear in network map")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_DisabledNetworkResourceIgnored(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
for _, r := range account.NetworkResources {
|
||||
r.Enabled = false
|
||||
}
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
|
||||
|
||||
for _, r := range nm.Routes {
|
||||
assert.NotEqual(t, "10.200.0.1/32", r.Network.String(), "disabled resource should not generate routes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_BidirectionalPolicy(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
nmSrc := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
nmDst := networkMapFromComponents(t, account, "peer-dst-1", validated)
|
||||
|
||||
assert.Contains(t, peerIDs(nmSrc.Peers), "peer-dst-1", "src should see dst via bidirectional policy")
|
||||
assert.Contains(t, peerIDs(nmDst.Peers), "peer-src-1", "dst should see src via bidirectional policy")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_DropPolicy(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
account.Policies = append(account.Policies, &types.Policy{
|
||||
ID: "policy-drop", Name: "Drop traffic", Enabled: true, AccountID: account.Id,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-drop", Name: "Drop src->dst", Enabled: true,
|
||||
Action: types.PolicyTrafficActionDrop, Protocol: types.PolicyRuleProtocolTCP,
|
||||
Ports: []string{"5432"},
|
||||
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
|
||||
}},
|
||||
})
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
var hasDropRule bool
|
||||
for _, rule := range nm.FirewallRules {
|
||||
if rule.Action == string(types.PolicyTrafficActionDrop) && rule.Port == "5432" {
|
||||
hasDropRule = true
|
||||
}
|
||||
}
|
||||
assert.True(t, hasDropRule, "drop policy should generate drop firewall rule")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_PortRangePolicy(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
account.Peers["peer-src-1"].Meta.WtVersion = "0.50.0"
|
||||
|
||||
account.Policies = append(account.Policies, &types.Policy{
|
||||
ID: "policy-range", Name: "Port range", Enabled: true, AccountID: account.Id,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-range", Name: "Range rule", Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP,
|
||||
PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}},
|
||||
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
|
||||
}},
|
||||
})
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
var hasRangeRule bool
|
||||
for _, rule := range nm.FirewallRules {
|
||||
if rule.PortRange.Start == 8080 && rule.PortRange.End == 8090 {
|
||||
hasRangeRule = true
|
||||
}
|
||||
}
|
||||
assert.True(t, hasRangeRule, "port range policy should generate corresponding firewall rule")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_MultipleNetworkResources(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
|
||||
ID: "resource-2", NetworkID: "net-1", AccountID: account.Id, Enabled: true,
|
||||
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.0.2/32"), Address: "10.200.0.2/32",
|
||||
})
|
||||
account.Groups["group-res2"] = &types.Group{ID: "group-res2", Name: "Resource 2 Group", Peers: []string{"peer-src-1", "peer-src-2"},
|
||||
Resources: []types.Resource{{ID: "resource-2"}},
|
||||
}
|
||||
account.Policies = append(account.Policies, &types.Policy{
|
||||
ID: "policy-res2", Name: "Resource 2 Policy", Enabled: true, AccountID: account.Id,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-res2", Name: "Access Resource 2", Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||
Sources: []string{"group-src"},
|
||||
DestinationResource: types.Resource{ID: "resource-2"},
|
||||
}},
|
||||
})
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
|
||||
|
||||
resourceRouteCount := 0
|
||||
for _, r := range nm.Routes {
|
||||
if r.Network.String() == "10.200.0.1/32" || r.Network.String() == "10.200.0.2/32" {
|
||||
resourceRouteCount++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 2, resourceRouteCount, "router should have routes for both network resources")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_DomainNetworkResource(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
|
||||
ID: "resource-domain", NetworkID: "net-1", AccountID: account.Id, Enabled: true,
|
||||
Type: resourceTypes.Domain, Domain: "api.example.com", Address: "api.example.com",
|
||||
})
|
||||
account.Groups["group-res-domain"] = &types.Group{
|
||||
ID: "group-res-domain", Name: "Domain Resource Group",
|
||||
Resources: []types.Resource{{ID: "resource-domain"}},
|
||||
}
|
||||
account.Policies = append(account.Policies, &types.Policy{
|
||||
ID: "policy-domain", Name: "Domain resource policy", Enabled: true, AccountID: account.Id,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-domain", Name: "Access domain resource", Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||
Sources: []string{"group-src"},
|
||||
DestinationResource: types.Resource{ID: "resource-domain"},
|
||||
}},
|
||||
})
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||
|
||||
var hasDomainRoute bool
|
||||
for _, r := range nm.Routes {
|
||||
if r.NetworkType == route.DomainNetwork && len(r.Domains) > 0 && r.Domains[0].SafeString() == "api.example.com" {
|
||||
hasDomainRoute = true
|
||||
}
|
||||
}
|
||||
assert.True(t, hasDomainRoute, "source peer should receive domain route for domain network resource")
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_NetworkEmpty(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
nm := networkMapFromComponents(t, account, "nonexistent-peer", validated)
|
||||
|
||||
assert.NotNil(t, nm)
|
||||
assert.Empty(t, nm.Peers)
|
||||
assert.Empty(t, nm.FirewallRules)
|
||||
assert.NotNil(t, nm.Network)
|
||||
}
|
||||
|
||||
func TestNetworkMapComponents_RouterExcludesOtherNetworkRoutes(t *testing.T) {
|
||||
account := createComponentTestAccount()
|
||||
validated := allPeersValidated(account)
|
||||
|
||||
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
|
||||
ID: "resource-other", NetworkID: "net-other", AccountID: account.Id, Enabled: true,
|
||||
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.99.1/32"), Address: "10.200.99.1/32",
|
||||
})
|
||||
account.Networks = append(account.Networks, &networkTypes.Network{
|
||||
ID: "net-other", Name: "Other Net", AccountID: account.Id,
|
||||
})
|
||||
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
|
||||
ID: "router-other", NetworkID: "net-other", Peer: "peer-dst-1", Enabled: true, AccountID: account.Id,
|
||||
})
|
||||
account.Groups["group-res-other"] = &types.Group{ID: "group-res-other", Name: "Other resource group",
|
||||
Resources: []types.Resource{{ID: "resource-other"}},
|
||||
}
|
||||
account.Policies = append(account.Policies, &types.Policy{
|
||||
ID: "policy-other-resource", Name: "Other resource policy", Enabled: true, AccountID: account.Id,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-other", Name: "Other resource access", Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||
Sources: []string{"group-src"},
|
||||
DestinationResource: types.Resource{ID: "resource-other"},
|
||||
}},
|
||||
})
|
||||
|
||||
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
|
||||
|
||||
for _, r := range nm.Routes {
|
||||
assert.NotEqual(t, "10.200.99.1/32", r.Network.String(), "router-1 should NOT get routes for other network's resources")
|
||||
}
|
||||
}
|
||||
|
||||
func createComponentTestAccount() *types.Account {
|
||||
peers := map[string]*nbpeer.Peer{
|
||||
"peer-src-1": {
|
||||
ID: "peer-src-1", IP: net.IP{100, 64, 0, 1}, Key: "key-src-1", DNSLabel: "src1",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
|
||||
},
|
||||
"peer-src-2": {
|
||||
ID: "peer-src-2", IP: net.IP{100, 64, 0, 2}, Key: "key-src-2", DNSLabel: "src2",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
|
||||
},
|
||||
"peer-dst-1": {
|
||||
ID: "peer-dst-1", IP: net.IP{100, 64, 0, 3}, Key: "key-dst-1", DNSLabel: "dst1",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-2",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
|
||||
},
|
||||
"peer-router-1": {
|
||||
ID: "peer-router-1", IP: net.IP{100, 64, 0, 10}, Key: "key-router-1", DNSLabel: "router1",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
|
||||
},
|
||||
}
|
||||
|
||||
groups := map[string]*types.Group{
|
||||
"group-src": {ID: "group-src", Name: "Sources", Peers: []string{"peer-src-1", "peer-src-2"}},
|
||||
"group-dst": {ID: "group-dst", Name: "Destinations", Peers: []string{"peer-dst-1"}},
|
||||
"group-all": {ID: "group-all", Name: "All", Peers: []string{"peer-src-1", "peer-src-2", "peer-dst-1", "peer-router-1"}},
|
||||
"group-res": {
|
||||
ID: "group-res", Name: "Resource Group",
|
||||
Resources: []types.Resource{{ID: "resource-1"}},
|
||||
},
|
||||
}
|
||||
|
||||
policies := []*types.Policy{
|
||||
{
|
||||
ID: "policy-base", Name: "Base connectivity", Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-base", Name: "Allow src <-> dst", Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||
Bidirectional: true,
|
||||
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
|
||||
}},
|
||||
},
|
||||
{
|
||||
ID: "policy-resource", Name: "Network resource access", Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: "rule-resource", Name: "Source -> Resource", Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||
Sources: []string{"group-src"},
|
||||
DestinationResource: types.Resource{ID: "resource-1"},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
routes := map[route.ID]*route.Route{
|
||||
"route-main": {
|
||||
ID: "route-main", Network: netip.MustParsePrefix("192.168.10.0/24"),
|
||||
Peer: peers["peer-dst-1"].Key, PeerID: "peer-dst-1",
|
||||
Enabled: true, Metric: 100,
|
||||
Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-dst"},
|
||||
},
|
||||
}
|
||||
|
||||
users := map[string]*types.User{
|
||||
"user-1": {Id: "user-1", Role: types.UserRoleAdmin, IsServiceUser: false, AutoGroups: []string{"group-all"}},
|
||||
"user-2": {Id: "user-2", Role: types.UserRoleUser, IsServiceUser: false, AutoGroups: []string{"group-all"}},
|
||||
}
|
||||
|
||||
account := &types.Account{
|
||||
Id: "account-components-test", Peers: peers, Groups: groups, Policies: policies, Routes: routes,
|
||||
Users: users,
|
||||
Network: &types.Network{
|
||||
Identifier: "net-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
|
||||
},
|
||||
DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{"group-dst"}},
|
||||
NameServerGroups: map[string]*nbdns.NameServerGroup{
|
||||
"ns-main": {
|
||||
ID: "ns-main", Name: "Main NS", Enabled: true, Groups: []string{"group-src"},
|
||||
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}},
|
||||
},
|
||||
},
|
||||
PostureChecks: []*posture.Checks{},
|
||||
NetworkResources: []*resourceTypes.NetworkResource{
|
||||
{
|
||||
ID: "resource-1", NetworkID: "net-1", AccountID: "account-components-test", Enabled: true,
|
||||
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.0.1/32"), Address: "10.200.0.1/32",
|
||||
},
|
||||
},
|
||||
Networks: []*networkTypes.Network{
|
||||
{ID: "net-1", Name: "Resource Net", AccountID: "account-components-test"},
|
||||
},
|
||||
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||
{ID: "router-1", NetworkID: "net-1", Peer: "peer-router-1", Enabled: true, AccountID: "account-components-test"},
|
||||
},
|
||||
Settings: &types.Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: 24 * time.Hour},
|
||||
}
|
||||
|
||||
for _, p := range account.Policies {
|
||||
p.AccountID = account.Id
|
||||
}
|
||||
for _, r := range account.Routes {
|
||||
r.AccountID = account.Id
|
||||
}
|
||||
|
||||
return account
|
||||
}
|
||||
@@ -1,967 +0,0 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const (
|
||||
numPeers = 100
|
||||
devGroupID = "group-dev"
|
||||
opsGroupID = "group-ops"
|
||||
allGroupID = "group-all"
|
||||
sshUsersGroupID = "group-ssh-users"
|
||||
routeID = route.ID("route-main")
|
||||
routeHA1ID = route.ID("route-ha-1")
|
||||
routeHA2ID = route.ID("route-ha-2")
|
||||
policyIDDevOps = "policy-dev-ops"
|
||||
policyIDAll = "policy-all"
|
||||
policyIDPosture = "policy-posture"
|
||||
policyIDDrop = "policy-drop"
|
||||
policyIDSSH = "policy-ssh"
|
||||
postureCheckID = "posture-check-ver"
|
||||
networkResourceID = "res-database"
|
||||
networkID = "net-database"
|
||||
networkRouterID = "router-database"
|
||||
nameserverGroupID = "ns-group-main"
|
||||
testingPeerID = "peer-60" // A peer from the "dev" group, should receive the most detailed map.
|
||||
expiredPeerID = "peer-98" // This peer will be online but with an expired session.
|
||||
offlinePeerID = "peer-99" // This peer will be completely offline.
|
||||
routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network.
|
||||
testAccountID = "account-golden-test"
|
||||
userAdminID = "user-admin"
|
||||
userDevID = "user-dev"
|
||||
userOpsID = "user-ops"
|
||||
)
|
||||
|
||||
func TestGetPeerNetworkMap_Golden(t *testing.T) {
|
||||
account := createTestAccountWithEntities()
|
||||
|
||||
ctx := context.Background()
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
if peerID == offlinePeerID {
|
||||
continue
|
||||
}
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
}
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
|
||||
if string(legacyJSON) != string(newJSON) {
|
||||
legacyFilePath := filepath.Join("testdata", "networkmap_golden.json")
|
||||
newFilePath := filepath.Join("testdata", "networkmap_golden_new.json")
|
||||
|
||||
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved legacy network map to %s", legacyFilePath)
|
||||
|
||||
err = os.WriteFile(newFilePath, newJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved new network map to %s", newFilePath)
|
||||
|
||||
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps from legacy and new builder do not match")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetPeerNetworkMap(b *testing.B) {
|
||||
account := createTestAccountWithEntities()
|
||||
ctx := context.Background()
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
var peerIDs []string
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.Run("old builder", func(b *testing.B) {
|
||||
for range b.N {
|
||||
for _, peerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
b.ResetTimer()
|
||||
b.Run("new builder", func(b *testing.B) {
|
||||
for range b.N {
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
for _, peerID := range peerIDs {
|
||||
_ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
|
||||
account := createTestAccountWithEntities()
|
||||
|
||||
ctx := context.Background()
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
if peerID == offlinePeerID {
|
||||
continue
|
||||
}
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
}
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
|
||||
newPeerID := "peer-new-101"
|
||||
newPeerIP := net.IP{100, 64, 1, 1}
|
||||
newPeer := &nbpeer.Peer{
|
||||
ID: newPeerID,
|
||||
IP: newPeerIP,
|
||||
Key: fmt.Sprintf("key-%s", newPeerID),
|
||||
DNSLabel: "peernew101",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
UserID: "user-admin",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
|
||||
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
|
||||
}
|
||||
|
||||
account.Peers[newPeerID] = newPeer
|
||||
|
||||
if devGroup, exists := account.Groups[devGroupID]; exists {
|
||||
devGroup.Peers = append(devGroup.Peers, newPeerID)
|
||||
}
|
||||
|
||||
if allGroup, exists := account.Groups[allGroupID]; exists {
|
||||
allGroup.Peers = append(allGroup.Peers, newPeerID)
|
||||
}
|
||||
|
||||
validatedPeersMap[newPeerID] = struct{}{}
|
||||
|
||||
if account.Network != nil {
|
||||
account.Network.Serial++
|
||||
}
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
|
||||
err = builder.OnPeerAddedIncremental(account, newPeerID)
|
||||
require.NoError(t, err, "error adding peer to cache")
|
||||
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
|
||||
if string(legacyJSON) != string(newJSON) {
|
||||
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json")
|
||||
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json")
|
||||
|
||||
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved legacy network map to %s", legacyFilePath)
|
||||
|
||||
err = os.WriteFile(newFilePath, newJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved new network map to %s", newFilePath)
|
||||
|
||||
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new peer from legacy and new builder do not match")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
||||
account := createTestAccountWithEntities()
|
||||
ctx := context.Background()
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
var peerIDs []string
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
newPeerID := "peer-new-101"
|
||||
newPeer := &nbpeer.Peer{
|
||||
ID: newPeerID,
|
||||
IP: net.IP{100, 64, 1, 1},
|
||||
Key: fmt.Sprintf("key-%s", newPeerID),
|
||||
DNSLabel: "peernew101",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
UserID: "user-admin",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
|
||||
}
|
||||
|
||||
account.Peers[newPeerID] = newPeer
|
||||
account.Groups[devGroupID].Peers = append(account.Groups[devGroupID].Peers, newPeerID)
|
||||
account.Groups[allGroupID].Peers = append(account.Groups[allGroupID].Peers, newPeerID)
|
||||
validatedPeersMap[newPeerID] = struct{}{}
|
||||
|
||||
b.ResetTimer()
|
||||
b.Run("old builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.ResetTimer()
|
||||
b.Run("new builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = builder.OnPeerAddedIncremental(account, newPeerID)
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
|
||||
account := createTestAccountWithEntities()
|
||||
|
||||
ctx := context.Background()
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
if peerID == offlinePeerID {
|
||||
continue
|
||||
}
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
}
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
|
||||
newRouterID := "peer-new-router-102"
|
||||
newRouterIP := net.IP{100, 64, 1, 2}
|
||||
newRouter := &nbpeer.Peer{
|
||||
ID: newRouterID,
|
||||
IP: newRouterIP,
|
||||
Key: fmt.Sprintf("key-%s", newRouterID),
|
||||
DNSLabel: "newrouter102",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
UserID: "user-admin",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
|
||||
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
|
||||
}
|
||||
|
||||
account.Peers[newRouterID] = newRouter
|
||||
|
||||
if opsGroup, exists := account.Groups[opsGroupID]; exists {
|
||||
opsGroup.Peers = append(opsGroup.Peers, newRouterID)
|
||||
}
|
||||
|
||||
if allGroup, exists := account.Groups[allGroupID]; exists {
|
||||
allGroup.Peers = append(allGroup.Peers, newRouterID)
|
||||
}
|
||||
|
||||
newRoute := &route.Route{
|
||||
ID: route.ID("route-new-router"),
|
||||
Network: netip.MustParsePrefix("172.16.0.0/24"),
|
||||
Peer: newRouter.Key,
|
||||
PeerID: newRouterID,
|
||||
Description: "Route from new router",
|
||||
Enabled: true,
|
||||
PeerGroups: []string{opsGroupID},
|
||||
Groups: []string{devGroupID, opsGroupID},
|
||||
AccessControlGroups: []string{devGroupID},
|
||||
AccountID: account.Id,
|
||||
}
|
||||
account.Routes[newRoute.ID] = newRoute
|
||||
|
||||
validatedPeersMap[newRouterID] = struct{}{}
|
||||
|
||||
if account.Network != nil {
|
||||
account.Network.Serial++
|
||||
}
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
|
||||
err = builder.OnPeerAddedIncremental(account, newRouterID)
|
||||
require.NoError(t, err, "error adding router to cache")
|
||||
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
|
||||
if string(legacyJSON) != string(newJSON) {
|
||||
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json")
|
||||
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")
|
||||
|
||||
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved legacy network map to %s", legacyFilePath)
|
||||
|
||||
err = os.WriteFile(newFilePath, newJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved new network map to %s", newFilePath)
|
||||
|
||||
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new router from legacy and new builder do not match")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
||||
account := createTestAccountWithEntities()
|
||||
ctx := context.Background()
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
var peerIDs []string
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
newRouterID := "peer-new-router-102"
|
||||
newRouterIP := net.IP{100, 64, 1, 2}
|
||||
newRouter := &nbpeer.Peer{
|
||||
ID: newRouterID,
|
||||
IP: newRouterIP,
|
||||
Key: fmt.Sprintf("key-%s", newRouterID),
|
||||
DNSLabel: "newrouter102",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
UserID: "user-admin",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
|
||||
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
|
||||
}
|
||||
|
||||
account.Peers[newRouterID] = newRouter
|
||||
|
||||
if opsGroup, exists := account.Groups[opsGroupID]; exists {
|
||||
opsGroup.Peers = append(opsGroup.Peers, newRouterID)
|
||||
}
|
||||
if allGroup, exists := account.Groups[allGroupID]; exists {
|
||||
allGroup.Peers = append(allGroup.Peers, newRouterID)
|
||||
}
|
||||
|
||||
newRoute := &route.Route{
|
||||
ID: route.ID("route-new-router"),
|
||||
Network: netip.MustParsePrefix("172.16.0.0/24"),
|
||||
Peer: newRouter.Key,
|
||||
PeerID: newRouterID,
|
||||
Description: "Route from new router",
|
||||
Enabled: true,
|
||||
PeerGroups: []string{opsGroupID},
|
||||
Groups: []string{devGroupID, opsGroupID},
|
||||
AccessControlGroups: []string{devGroupID},
|
||||
AccountID: account.Id,
|
||||
}
|
||||
account.Routes[newRoute.ID] = newRoute
|
||||
|
||||
validatedPeersMap[newRouterID] = struct{}{}
|
||||
|
||||
b.ResetTimer()
|
||||
b.Run("old builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.ResetTimer()
|
||||
b.Run("new builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = builder.OnPeerAddedIncremental(account, newRouterID)
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
|
||||
account := createTestAccountWithEntities()
|
||||
|
||||
ctx := context.Background()
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
if peerID == offlinePeerID {
|
||||
continue
|
||||
}
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
}
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
|
||||
deletedPeerID := "peer-25"
|
||||
|
||||
delete(account.Peers, deletedPeerID)
|
||||
|
||||
if devGroup, exists := account.Groups[devGroupID]; exists {
|
||||
devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool {
|
||||
return id == deletedPeerID
|
||||
})
|
||||
}
|
||||
|
||||
if allGroup, exists := account.Groups[allGroupID]; exists {
|
||||
allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool {
|
||||
return id == deletedPeerID
|
||||
})
|
||||
}
|
||||
|
||||
delete(validatedPeersMap, deletedPeerID)
|
||||
|
||||
if account.Network != nil {
|
||||
account.Network.Serial++
|
||||
}
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
|
||||
err = builder.OnPeerDeleted(account, deletedPeerID)
|
||||
require.NoError(t, err, "error deleting peer from cache")
|
||||
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
|
||||
if string(legacyJSON) != string(newJSON) {
|
||||
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json")
|
||||
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json")
|
||||
|
||||
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved legacy network map to %s", legacyFilePath)
|
||||
|
||||
err = os.WriteFile(newFilePath, newJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved new network map to %s", newFilePath)
|
||||
|
||||
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted peer from legacy and new builder do not match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
||||
account := createTestAccountWithEntities()
|
||||
|
||||
ctx := context.Background()
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
if peerID == offlinePeerID {
|
||||
continue
|
||||
}
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
}
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
|
||||
deletedRouterID := "peer-75"
|
||||
|
||||
var affectedRoute *route.Route
|
||||
for _, r := range account.Routes {
|
||||
if r.PeerID == deletedRouterID {
|
||||
affectedRoute = r
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, affectedRoute, "Router peer should have a route")
|
||||
|
||||
for _, group := range account.Groups {
|
||||
group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool {
|
||||
return id == deletedRouterID
|
||||
})
|
||||
}
|
||||
|
||||
for routeID, r := range account.Routes {
|
||||
if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID {
|
||||
delete(account.Routes, routeID)
|
||||
}
|
||||
}
|
||||
delete(account.Peers, deletedRouterID)
|
||||
delete(validatedPeersMap, deletedRouterID)
|
||||
|
||||
if account.Network != nil {
|
||||
account.Network.Serial++
|
||||
}
|
||||
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
|
||||
err = builder.OnPeerDeleted(account, deletedRouterID)
|
||||
require.NoError(t, err, "error deleting routing peer from cache")
|
||||
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
|
||||
if string(legacyJSON) != string(newJSON) {
|
||||
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json")
|
||||
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json")
|
||||
|
||||
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved legacy network map to %s", legacyFilePath)
|
||||
|
||||
err = os.WriteFile(newFilePath, newJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved new network map to %s", newFilePath)
|
||||
|
||||
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted router from legacy and new builder do not match")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
||||
account := createTestAccountWithEntities()
|
||||
ctx := context.Background()
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
var peerIDs []string
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
|
||||
deletedPeerID := "peer-25"
|
||||
|
||||
delete(account.Peers, deletedPeerID)
|
||||
account.Groups[devGroupID].Peers = slices.DeleteFunc(account.Groups[devGroupID].Peers, func(id string) bool {
|
||||
return id == deletedPeerID
|
||||
})
|
||||
account.Groups[allGroupID].Peers = slices.DeleteFunc(account.Groups[allGroupID].Peers, func(id string) bool {
|
||||
return id == deletedPeerID
|
||||
})
|
||||
delete(validatedPeersMap, deletedPeerID)
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
|
||||
b.ResetTimer()
|
||||
b.Run("old builder after delete", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.ResetTimer()
|
||||
b.Run("new builder after delete", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = builder.OnPeerDeleted(account, deletedPeerID)
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) {
|
||||
for _, peer := range networkMap.Peers {
|
||||
if peer.Status != nil {
|
||||
peer.Status.LastSeen = time.Time{}
|
||||
}
|
||||
peer.LastLogin = &time.Time{}
|
||||
}
|
||||
for _, peer := range networkMap.OfflinePeers {
|
||||
if peer.Status != nil {
|
||||
peer.Status.LastSeen = time.Time{}
|
||||
}
|
||||
peer.LastLogin = &time.Time{}
|
||||
}
|
||||
|
||||
sort.Slice(networkMap.Peers, func(i, j int) bool { return networkMap.Peers[i].ID < networkMap.Peers[j].ID })
|
||||
sort.Slice(networkMap.OfflinePeers, func(i, j int) bool { return networkMap.OfflinePeers[i].ID < networkMap.OfflinePeers[j].ID })
|
||||
sort.Slice(networkMap.Routes, func(i, j int) bool { return networkMap.Routes[i].ID < networkMap.Routes[j].ID })
|
||||
|
||||
sort.Slice(networkMap.FirewallRules, func(i, j int) bool {
|
||||
r1, r2 := networkMap.FirewallRules[i], networkMap.FirewallRules[j]
|
||||
if r1.PeerIP != r2.PeerIP {
|
||||
return r1.PeerIP < r2.PeerIP
|
||||
}
|
||||
if r1.Protocol != r2.Protocol {
|
||||
return r1.Protocol < r2.Protocol
|
||||
}
|
||||
if r1.Direction != r2.Direction {
|
||||
return r1.Direction < r2.Direction
|
||||
}
|
||||
if r1.Action != r2.Action {
|
||||
return r1.Action < r2.Action
|
||||
}
|
||||
return r1.Port < r2.Port
|
||||
})
|
||||
|
||||
sort.Slice(networkMap.RoutesFirewallRules, func(i, j int) bool {
|
||||
r1, r2 := networkMap.RoutesFirewallRules[i], networkMap.RoutesFirewallRules[j]
|
||||
if r1.RouteID != r2.RouteID {
|
||||
return r1.RouteID < r2.RouteID
|
||||
}
|
||||
if r1.Action != r2.Action {
|
||||
return r1.Action < r2.Action
|
||||
}
|
||||
if r1.Destination != r2.Destination {
|
||||
return r1.Destination < r2.Destination
|
||||
}
|
||||
if len(r1.SourceRanges) > 0 && len(r2.SourceRanges) > 0 {
|
||||
if r1.SourceRanges[0] != r2.SourceRanges[0] {
|
||||
return r1.SourceRanges[0] < r2.SourceRanges[0]
|
||||
}
|
||||
}
|
||||
return r1.Port < r2.Port
|
||||
})
|
||||
|
||||
for _, ranges := range networkMap.RoutesFirewallRules {
|
||||
sort.Slice(ranges.SourceRanges, func(i, j int) bool {
|
||||
return ranges.SourceRanges[i] < ranges.SourceRanges[j]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type networkMapJSON struct {
|
||||
Peers []*nbpeer.Peer `json:"Peers"`
|
||||
Network *types.Network `json:"Network"`
|
||||
Routes []*route.Route `json:"Routes"`
|
||||
DNSConfig dns.Config `json:"DNSConfig"`
|
||||
OfflinePeers []*nbpeer.Peer `json:"OfflinePeers"`
|
||||
FirewallRules []*types.FirewallRule `json:"FirewallRules"`
|
||||
RoutesFirewallRules []*types.RouteFirewallRule `json:"RoutesFirewallRules"`
|
||||
ForwardingRules []*types.ForwardingRule `json:"ForwardingRules"`
|
||||
AuthorizedUsers map[string][]string `json:"AuthorizedUsers,omitempty"`
|
||||
EnableSSH bool `json:"EnableSSH"`
|
||||
}
|
||||
|
||||
func toNetworkMapJSON(nm *types.NetworkMap) *networkMapJSON {
|
||||
result := &networkMapJSON{
|
||||
Peers: nm.Peers,
|
||||
Network: nm.Network,
|
||||
Routes: nm.Routes,
|
||||
DNSConfig: nm.DNSConfig,
|
||||
OfflinePeers: nm.OfflinePeers,
|
||||
FirewallRules: nm.FirewallRules,
|
||||
RoutesFirewallRules: nm.RoutesFirewallRules,
|
||||
ForwardingRules: nm.ForwardingRules,
|
||||
EnableSSH: nm.EnableSSH,
|
||||
}
|
||||
|
||||
if len(nm.AuthorizedUsers) > 0 {
|
||||
result.AuthorizedUsers = make(map[string][]string)
|
||||
localUsers := make([]string, 0, len(nm.AuthorizedUsers))
|
||||
for localUser := range nm.AuthorizedUsers {
|
||||
localUsers = append(localUsers, localUser)
|
||||
}
|
||||
sort.Strings(localUsers)
|
||||
|
||||
for _, localUser := range localUsers {
|
||||
userIDs := nm.AuthorizedUsers[localUser]
|
||||
sortedUserIDs := make([]string, 0, len(userIDs))
|
||||
for userID := range userIDs {
|
||||
sortedUserIDs = append(sortedUserIDs, userID)
|
||||
}
|
||||
sort.Strings(sortedUserIDs)
|
||||
result.AuthorizedUsers[localUser] = sortedUserIDs
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func createTestAccountWithEntities() *types.Account {
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{}
|
||||
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
ip := net.IP{100, 64, 0, byte(i + 1)}
|
||||
wtVersion := "0.25.0"
|
||||
if i%2 == 0 {
|
||||
wtVersion = "0.40.0"
|
||||
}
|
||||
|
||||
p := &nbpeer.Peer{
|
||||
ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1),
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"},
|
||||
}
|
||||
|
||||
if peerID == expiredPeerID {
|
||||
p.LoginExpirationEnabled = true
|
||||
pastTimestamp := time.Now().Add(-2 * time.Hour)
|
||||
p.LastLogin = &pastTimestamp
|
||||
}
|
||||
|
||||
peers[peerID] = p
|
||||
allGroupPeers = append(allGroupPeers, peerID)
|
||||
if i < numPeers/2 {
|
||||
devGroupPeers = append(devGroupPeers, peerID)
|
||||
} else {
|
||||
opsGroupPeers = append(opsGroupPeers, peerID)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
groups := map[string]*types.Group{
|
||||
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
|
||||
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
|
||||
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
|
||||
sshUsersGroupID: {ID: sshUsersGroupID, Name: "SSH Users", Peers: []string{}},
|
||||
}
|
||||
|
||||
policies := []*types.Policy{
|
||||
{
|
||||
ID: policyIDAll, Name: "Default-Allow", Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: policyIDAll, Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolALL, Bidirectional: true,
|
||||
Sources: []string{allGroupID}, Destinations: []string{allGroupID},
|
||||
}},
|
||||
},
|
||||
{
|
||||
ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolTCP, Bidirectional: false,
|
||||
PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}},
|
||||
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
|
||||
}},
|
||||
},
|
||||
{
|
||||
ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop,
|
||||
Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true,
|
||||
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
|
||||
}},
|
||||
},
|
||||
{
|
||||
ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true,
|
||||
SourcePostureChecks: []string{postureCheckID},
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolALL, Bidirectional: true,
|
||||
Sources: []string{opsGroupID}, DestinationResource: types.Resource{ID: networkResourceID},
|
||||
}},
|
||||
},
|
||||
{
|
||||
ID: policyIDSSH, Name: "SSH Access Policy", Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: policyIDSSH, Name: "Allow SSH to Ops", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolNetbirdSSH, Bidirectional: false,
|
||||
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
|
||||
AuthorizedGroups: map[string][]string{sshUsersGroupID: {"root", "admin"}},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
routes := map[route.ID]*route.Route{
|
||||
routeID: {
|
||||
ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"),
|
||||
Peer: peers["peer-75"].Key,
|
||||
PeerID: "peer-75",
|
||||
Description: "Route to internal resource", Enabled: true,
|
||||
PeerGroups: []string{devGroupID, opsGroupID},
|
||||
Groups: []string{devGroupID, opsGroupID},
|
||||
AccessControlGroups: []string{devGroupID},
|
||||
},
|
||||
routeHA1ID: {
|
||||
ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
|
||||
Peer: peers["peer-80"].Key,
|
||||
PeerID: "peer-80",
|
||||
Description: "HA Route 1", Enabled: true, Metric: 1000,
|
||||
PeerGroups: []string{allGroupID},
|
||||
Groups: []string{allGroupID},
|
||||
AccessControlGroups: []string{allGroupID},
|
||||
},
|
||||
routeHA2ID: {
|
||||
ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
|
||||
Peer: peers["peer-90"].Key,
|
||||
PeerID: "peer-90",
|
||||
Description: "HA Route 2", Enabled: true, Metric: 900,
|
||||
PeerGroups: []string{devGroupID, opsGroupID},
|
||||
Groups: []string{devGroupID, opsGroupID},
|
||||
AccessControlGroups: []string{allGroupID},
|
||||
},
|
||||
}
|
||||
|
||||
users := map[string]*types.User{
|
||||
userAdminID: {Id: userAdminID, Role: types.UserRoleAdmin, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{allGroupID}},
|
||||
userDevID: {Id: userDevID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, devGroupID}},
|
||||
userOpsID: {Id: userOpsID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, opsGroupID}},
|
||||
}
|
||||
|
||||
account := &types.Account{
|
||||
Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes,
|
||||
Users: users,
|
||||
Network: &types.Network{
|
||||
Identifier: "net-golden-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
|
||||
},
|
||||
DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{opsGroupID}},
|
||||
NameServerGroups: map[string]*dns.NameServerGroup{
|
||||
nameserverGroupID: {
|
||||
ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID},
|
||||
NameServers: []dns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: dns.UDPNameServerType, Port: 53}},
|
||||
},
|
||||
},
|
||||
PostureChecks: []*posture.Checks{
|
||||
{ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
|
||||
}},
|
||||
},
|
||||
NetworkResources: []*resourceTypes.NetworkResource{
|
||||
{ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"},
|
||||
},
|
||||
Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}},
|
||||
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||
{ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID},
|
||||
},
|
||||
Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
|
||||
}
|
||||
|
||||
for _, p := range account.Policies {
|
||||
p.AccountID = account.Id
|
||||
}
|
||||
for _, r := range account.Routes {
|
||||
r.AccountID = account.Id
|
||||
}
|
||||
|
||||
return account
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched(t *testing.T) {
|
||||
account := createTestAccountWithEntities()
|
||||
|
||||
ctx := context.Background()
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
if peerID == offlinePeerID {
|
||||
continue
|
||||
}
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
}
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
|
||||
newRouterID := "peer-new-router-102"
|
||||
newRouterIP := net.IP{100, 64, 1, 2}
|
||||
newRouter := &nbpeer.Peer{
|
||||
ID: newRouterID,
|
||||
IP: newRouterIP,
|
||||
Key: fmt.Sprintf("key-%s", newRouterID),
|
||||
DNSLabel: "newrouter102",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
UserID: "user-admin",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
|
||||
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
|
||||
}
|
||||
|
||||
account.Peers[newRouterID] = newRouter
|
||||
|
||||
if opsGroup, exists := account.Groups[opsGroupID]; exists {
|
||||
opsGroup.Peers = append(opsGroup.Peers, newRouterID)
|
||||
}
|
||||
if allGroup, exists := account.Groups[allGroupID]; exists {
|
||||
allGroup.Peers = append(allGroup.Peers, newRouterID)
|
||||
}
|
||||
|
||||
newRoute := &route.Route{
|
||||
ID: route.ID("route-new-router"),
|
||||
Network: netip.MustParsePrefix("172.16.0.0/24"),
|
||||
Peer: newRouter.Key,
|
||||
PeerID: newRouterID,
|
||||
Description: "Route from new router",
|
||||
Enabled: true,
|
||||
PeerGroups: []string{opsGroupID},
|
||||
Groups: []string{devGroupID, opsGroupID},
|
||||
AccessControlGroups: []string{devGroupID},
|
||||
AccountID: account.Id,
|
||||
}
|
||||
account.Routes[newRoute.ID] = newRoute
|
||||
|
||||
validatedPeersMap[newRouterID] = struct{}{}
|
||||
|
||||
if account.Network != nil {
|
||||
account.Network.Serial++
|
||||
}
|
||||
|
||||
builder.EnqueuePeersForIncrementalAdd(account, newRouterID)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
jsonData, err := json.MarshalIndent(networkMap, "", " ")
|
||||
require.NoError(t, err, "error marshaling network map to JSON")
|
||||
|
||||
goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")
|
||||
|
||||
t.Log("Update golden file with OnPeerAdded router...")
|
||||
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(goldenFilePath, jsonData, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedJSON, err := os.ReadFile(goldenFilePath)
|
||||
require.NoError(t, err, "error reading golden file")
|
||||
|
||||
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1586,7 +1586,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1609,7 +1609,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -433,6 +433,7 @@ func setSessionCookie(w http.ResponseWriter, token string, expiration time.Durat
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: auth.SessionCookieName,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
|
||||
@@ -391,6 +391,15 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
assert.Equal(t, http.SameSiteLaxMode, sessionCookie.SameSite)
|
||||
}
|
||||
|
||||
func TestSetSessionCookieHasRootPath(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
setSessionCookie(w, "test-token", time.Hour)
|
||||
|
||||
cookies := w.Result().Cookies()
|
||||
require.Len(t, cookies, 1)
|
||||
assert.Equal(t, "/", cookies[0].Path, "session cookie must be scoped to root so it applies to all paths")
|
||||
}
|
||||
|
||||
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
@@ -3425,6 +3425,17 @@ components:
|
||||
description: Display name for the admin user (defaults to email if not provided)
|
||||
type: string
|
||||
example: Admin User
|
||||
create_pat:
|
||||
description: If true and the server has setup-time PAT issuance enabled (NB_SETUP_PAT_ENABLED=true), create a Personal Access Token for the new owner user and return it in the response. Ignored when the server feature is disabled.
|
||||
type: boolean
|
||||
example: true
|
||||
pat_expire_in:
|
||||
description: Expiration of the Personal Access Token in days. Defaults to 1 day when omitted.
|
||||
type: integer
|
||||
minimum: 1
|
||||
maximum: 365
|
||||
default: 1
|
||||
example: 30
|
||||
required:
|
||||
- email
|
||||
- password
|
||||
@@ -3441,6 +3452,10 @@ components:
|
||||
description: Email address of the created user
|
||||
type: string
|
||||
example: admin@example.com
|
||||
personal_access_token:
|
||||
description: Plain text Personal Access Token created during setup. Present only when create_pat was requested and the NB_SETUP_PAT_ENABLED feature was enabled on the server.
|
||||
type: string
|
||||
example: nbp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
required:
|
||||
- user_id
|
||||
- email
|
||||
@@ -4979,7 +4994,10 @@ paths:
|
||||
/api/setup:
|
||||
post:
|
||||
summary: Setup Instance
|
||||
description: Creates the initial admin user for the instance. This endpoint does not require authentication but only works when setup is required (no accounts exist and embedded IDP is enabled).
|
||||
description: |
|
||||
Creates the initial admin user for the instance. This endpoint does not require authentication but only works when setup is required (no accounts exist and embedded IDP is enabled).
|
||||
|
||||
When the management server is started with `NB_SETUP_PAT_ENABLED=true` and the request includes `create_pat: true`, the endpoint also provisions the NetBird account for the new owner user and returns the plain text Personal Access Token in `personal_access_token`. The optional `pat_expire_in` value defaults to 1 day when omitted. If any post-user step fails the Dex user is rolled back and setup remains retryable.
|
||||
tags: [ Instance ]
|
||||
security: [ ]
|
||||
requestBody:
|
||||
|
||||
@@ -4294,6 +4294,9 @@ type SetupKeyRequest struct {
|
||||
|
||||
// SetupRequest Request to set up the initial admin user
|
||||
type SetupRequest struct {
|
||||
// CreatePat If true and the server has setup-time PAT issuance enabled (NB_SETUP_PAT_ENABLED=true), create a Personal Access Token for the new owner user and return it in the response. Ignored when the server feature is disabled.
|
||||
CreatePat *bool `json:"create_pat,omitempty"`
|
||||
|
||||
// Email Email address for the admin user
|
||||
Email string `json:"email"`
|
||||
|
||||
@@ -4302,6 +4305,9 @@ type SetupRequest struct {
|
||||
|
||||
// Password Password for the admin user (minimum 8 characters)
|
||||
Password string `json:"password"`
|
||||
|
||||
// PatExpireIn Expiration of the Personal Access Token in days. Defaults to 1 day when omitted.
|
||||
PatExpireIn *int `json:"pat_expire_in,omitempty"`
|
||||
}
|
||||
|
||||
// SetupResponse Response after successful instance setup
|
||||
@@ -4309,6 +4315,9 @@ type SetupResponse struct {
|
||||
// Email Email address of the created user
|
||||
Email string `json:"email"`
|
||||
|
||||
// PersonalAccessToken Plain text Personal Access Token created during setup. Present only when create_pat was requested and the NB_SETUP_PAT_ENABLED feature was enabled on the server.
|
||||
PersonalAccessToken *string `json:"personal_access_token,omitempty"`
|
||||
|
||||
// UserId The ID of the created user
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user