mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-28 21:26:40 +00:00
Compare commits
18 Commits
ci/freebsd
...
feat/byod-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
49f170c21e | ||
|
|
f78bc6113e | ||
|
|
4f0d6ef8f9 | ||
|
|
e6a663ba20 | ||
|
|
7c4011d8e2 | ||
|
|
154b81645a | ||
|
|
34167c8a16 | ||
|
|
8fe2b5ec1e | ||
|
|
e62521132c | ||
|
|
de3cb06067 | ||
|
|
4fdc39c8f8 | ||
|
|
94149a9441 | ||
|
|
38fd73fad6 | ||
|
|
9dd76b5a07 | ||
|
|
0b5380a7dc | ||
|
|
177171e437 | ||
|
|
da57b0f276 | ||
|
|
26ba03f08e |
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.1.3"
|
SIGN_PIPE_VER: "v0.1.4"
|
||||||
GORELEASER_VER: "v2.14.3"
|
GORELEASER_VER: "v2.14.3"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -16,11 +15,9 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
"golang.org/x/mod/semver"
|
"golang.org/x/mod/semver"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
@@ -58,13 +55,6 @@ type Controller struct {
|
|||||||
proxyController port_forwarding.Controller
|
proxyController port_forwarding.Controller
|
||||||
|
|
||||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||||
|
|
||||||
holder *types.Holder
|
|
||||||
|
|
||||||
expNewNetworkMap bool
|
|
||||||
expNewNetworkMapAIDs map[string]struct{}
|
|
||||||
|
|
||||||
compactedNetworkMap bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type bufferUpdate struct {
|
type bufferUpdate struct {
|
||||||
@@ -81,29 +71,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
|||||||
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
|
|
||||||
newNetworkMapBuilder = false
|
|
||||||
}
|
|
||||||
|
|
||||||
compactedNetworkMap := true
|
|
||||||
compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted)
|
|
||||||
parsedCompactedNmap, err := strconv.ParseBool(compactedEnv)
|
|
||||||
if err != nil && len(compactedEnv) > 0 {
|
|
||||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err)
|
|
||||||
}
|
|
||||||
if err == nil && !parsedCompactedNmap {
|
|
||||||
log.WithContext(ctx).Info("disabling compacted mode")
|
|
||||||
compactedNetworkMap = false
|
|
||||||
}
|
|
||||||
|
|
||||||
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
|
|
||||||
expIDs := make(map[string]struct{}, len(ids))
|
|
||||||
for _, id := range ids {
|
|
||||||
expIDs[id] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Controller{
|
return &Controller{
|
||||||
repo: newRepository(store),
|
repo: newRepository(store),
|
||||||
metrics: nMetrics,
|
metrics: nMetrics,
|
||||||
@@ -117,12 +84,6 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
|||||||
|
|
||||||
proxyController: proxyController,
|
proxyController: proxyController,
|
||||||
EphemeralPeersManager: ephemeralPeersManager,
|
EphemeralPeersManager: ephemeralPeersManager,
|
||||||
|
|
||||||
holder: types.NewHolder(),
|
|
||||||
expNewNetworkMap: newNetworkMapBuilder,
|
|
||||||
expNewNetworkMapAIDs: expIDs,
|
|
||||||
|
|
||||||
compactedNetworkMap: compactedNetworkMap,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -153,17 +114,9 @@ func (c *Controller) CountStreams() int {
|
|||||||
|
|
||||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||||
var (
|
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
account *types.Account
|
if err != nil {
|
||||||
err error
|
return fmt.Errorf("failed to get account: %v", err)
|
||||||
)
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
|
||||||
account = c.getAccountFromHolderOrInit(ctx, accountID)
|
|
||||||
} else {
|
|
||||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get account: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
globalStart := time.Now()
|
globalStart := time.Now()
|
||||||
@@ -197,10 +150,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
|
||||||
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||||
@@ -243,16 +192,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||||
start = time.Now()
|
start = time.Now()
|
||||||
|
|
||||||
var remotePeerNetworkMap *types.NetworkMap
|
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
|
||||||
switch {
|
|
||||||
case c.experimentalNetworkMap(accountID):
|
|
||||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
|
||||||
case c.compactedNetworkMap:
|
|
||||||
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
|
||||||
default:
|
|
||||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||||
|
|
||||||
@@ -318,10 +258,6 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
|
|||||||
// UpdatePeers updates all peers that belong to an account.
|
// UpdatePeers updates all peers that belong to an account.
|
||||||
// Should be called when changes have to be synced to peers.
|
// Should be called when changes have to be synced to peers.
|
||||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||||
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
|
||||||
return fmt.Errorf("recalculate network map cache: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
return c.sendUpdateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -371,16 +307,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var remotePeerNetworkMap *types.NetworkMap
|
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
|
||||||
switch {
|
|
||||||
case c.experimentalNetworkMap(accountId):
|
|
||||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
|
||||||
case c.compactedNetworkMap:
|
|
||||||
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
|
||||||
default:
|
|
||||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
if ok {
|
if ok {
|
||||||
@@ -451,17 +378,9 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
return peer, emptyMap, nil, 0, nil
|
return peer, emptyMap, nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
account *types.Account
|
if err != nil {
|
||||||
err error
|
return nil, nil, nil, 0, err
|
||||||
)
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
|
||||||
account = c.getAccountFromHolderOrInit(ctx, accountID)
|
|
||||||
} else {
|
|
||||||
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, 0, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
account.InjectProxyPolicies(ctx)
|
account.InjectProxyPolicies(ctx)
|
||||||
@@ -493,20 +412,10 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
return nil, nil, nil, 0, err
|
return nil, nil, nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var networkMap *types.NetworkMap
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
|
routers := account.GetResourceRoutersMap()
|
||||||
if c.experimentalNetworkMap(accountID) {
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
} else {
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
|
||||||
if c.compactedNetworkMap {
|
|
||||||
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
|
||||||
} else {
|
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
if ok {
|
if ok {
|
||||||
@@ -518,108 +427,6 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
return peer, networkMap, postureChecks, dnsFwdPort, nil
|
return peer, networkMap, postureChecks, dnsFwdPort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
|
|
||||||
c.enrichAccountFromHolder(account)
|
|
||||||
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) getPeerNetworkMapExp(
|
|
||||||
ctx context.Context,
|
|
||||||
accountId string,
|
|
||||||
peerId string,
|
|
||||||
validatedPeers map[string]struct{},
|
|
||||||
peersCustomZone nbdns.CustomZone,
|
|
||||||
accountZones []*zones.Zone,
|
|
||||||
metrics *telemetry.AccountManagerMetrics,
|
|
||||||
) *types.NetworkMap {
|
|
||||||
account := c.getAccountFromHolderOrInit(ctx, accountId)
|
|
||||||
if account == nil {
|
|
||||||
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
|
|
||||||
return &types.NetworkMap{
|
|
||||||
Network: &types.Network{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
|
|
||||||
c.enrichAccountFromHolder(account)
|
|
||||||
account.OnPeersAddedUpdNetworkMapCache(peerIds...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
|
||||||
c.enrichAccountFromHolder(account)
|
|
||||||
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
|
|
||||||
account := c.getAccountFromHolder(accountId)
|
|
||||||
if account == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
account.UpdatePeerInNetworkMapCache(peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
|
|
||||||
account.RecalculateNetworkMapCache(validatedPeers)
|
|
||||||
c.updateAccountInHolder(account)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
|
|
||||||
if c.experimentalNetworkMap(accountId) {
|
|
||||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.recalculateNetworkMapCache(account, validatedPeers)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) experimentalNetworkMap(accountId string) bool {
|
|
||||||
_, ok := c.expNewNetworkMapAIDs[accountId]
|
|
||||||
return c.expNewNetworkMap || ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
|
|
||||||
a := c.holder.GetAccount(account.Id)
|
|
||||||
if a == nil {
|
|
||||||
c.holder.AddAccount(account)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
account.NetworkMapCache = a.NetworkMapCache
|
|
||||||
if account.NetworkMapCache == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.holder.AddAccount(account)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
|
|
||||||
return c.holder.GetAccount(accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account {
|
|
||||||
a := c.holder.GetAccount(accountID)
|
|
||||||
if a != nil {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return account
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Controller) updateAccountInHolder(account *types.Account) {
|
|
||||||
c.holder.AddAccount(account)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDNSDomain returns the configured dnsDomain
|
// GetDNSDomain returns the configured dnsDomain
|
||||||
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
|
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
|
||||||
if settings == nil {
|
if settings == nil {
|
||||||
@@ -756,16 +563,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
||||||
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
|
err := c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get peers by ids: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, peer := range peers {
|
|
||||||
c.UpdatePeerInNetworkMapCache(accountID, peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
||||||
}
|
}
|
||||||
@@ -775,14 +573,6 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
|
|||||||
|
|
||||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||||
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
|
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
|
||||||
if c.experimentalNetworkMap(accountID) {
|
|
||||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs)
|
|
||||||
c.onPeersAddedUpdNetworkMapCache(account, peerIDs...)
|
|
||||||
}
|
|
||||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -817,19 +607,6 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
|||||||
MessageType: network_map.MessageTypeNetworkMap,
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
})
|
})
|
||||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
|
||||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||||
@@ -872,21 +649,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var networkMap *types.NetworkMap
|
account.InjectProxyPolicies(ctx)
|
||||||
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
if c.experimentalNetworkMap(peer.AccountID) {
|
routers := account.GetResourceRoutersMap()
|
||||||
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
} else {
|
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||||
account.InjectProxyPolicies(ctx)
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
|
||||||
if c.compactedNetworkMap {
|
|
||||||
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
|
||||||
} else {
|
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
if ok {
|
if ok {
|
||||||
|
|||||||
@@ -12,9 +12,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
|
|
||||||
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
|
|
||||||
|
|
||||||
DnsForwarderPort = nbdns.ForwarderServerPort
|
DnsForwarderPort = nbdns.ForwarderServerPort
|
||||||
OldForwarderPort = nbdns.ForwarderClientPort
|
OldForwarderPort = nbdns.ForwarderClientPort
|
||||||
DnsForwarderPortMinVersion = "v0.59.0"
|
DnsForwarderPortMinVersion = "v0.59.0"
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ type store interface {
|
|||||||
|
|
||||||
type proxyManager interface {
|
type proxyManager interface {
|
||||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||||
|
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||||
@@ -71,8 +72,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
|||||||
var ret []*domain.Domain
|
var ret []*domain.Domain
|
||||||
|
|
||||||
// Add connected proxy clusters as free domains.
|
// Add connected proxy clusters as free domains.
|
||||||
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
|
// For BYOP accounts, only their own cluster is returned; otherwise shared clusters.
|
||||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -126,8 +127,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the target cluster is in the available clusters
|
// Verify the target cluster is in the available clusters for this account
|
||||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||||
}
|
}
|
||||||
@@ -273,7 +274,7 @@ func (m Manager) GetClusterDomains() []string {
|
|||||||
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
||||||
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
||||||
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
||||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
allowList, err := m.getClusterAllowList(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||||
}
|
}
|
||||||
@@ -298,6 +299,17 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain
|
|||||||
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
|
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) {
|
||||||
|
byopAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get BYOP cluster addresses: %w", err)
|
||||||
|
}
|
||||||
|
if len(byopAddresses) > 0 {
|
||||||
|
return byopAddresses, nil
|
||||||
|
}
|
||||||
|
return m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
|
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
|
||||||
bestCluster := ""
|
bestCluster := ""
|
||||||
bestLen := -1
|
bestLen := -1
|
||||||
|
|||||||
@@ -0,0 +1,110 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockProxyManager struct {
|
||||||
|
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||||
|
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||||
|
if m.getActiveClusterAddressesFunc != nil {
|
||||||
|
return m.getActiveClusterAddressesFunc(ctx)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||||
|
if m.getActiveClusterAddressesForAccountFunc != nil {
|
||||||
|
return m.getActiveClusterAddressesForAccountFunc(ctx, accountID)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetClusterAllowList_BYOPProxy(t *testing.T) {
|
||||||
|
pm := &mockProxyManager{
|
||||||
|
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||||
|
assert.Equal(t, "acc-123", accID)
|
||||||
|
return []string{"byop.example.com"}, nil
|
||||||
|
},
|
||||||
|
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||||
|
t.Fatal("should not call GetActiveClusterAddresses when BYOP addresses exist")
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := Manager{proxyManager: pm}
|
||||||
|
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"byop.example.com"}, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
|
||||||
|
pm := &mockProxyManager{
|
||||||
|
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||||
|
return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := Manager{proxyManager: pm}
|
||||||
|
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
|
||||||
|
pm := &mockProxyManager{
|
||||||
|
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||||
|
return nil, errors.New("db error")
|
||||||
|
},
|
||||||
|
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||||
|
t.Fatal("should not call GetActiveClusterAddresses when BYOP lookup fails")
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := Manager{proxyManager: pm}
|
||||||
|
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, result)
|
||||||
|
assert.Contains(t, err.Error(), "BYOP cluster addresses")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
|
||||||
|
pm := &mockProxyManager{
|
||||||
|
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
|
||||||
|
return []string{}, nil
|
||||||
|
},
|
||||||
|
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
|
||||||
|
return []string{"eu.proxy.netbird.io"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := Manager{proxyManager: pm}
|
||||||
|
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
|
||||||
|
}
|
||||||
|
|
||||||
@@ -11,15 +11,19 @@ import (
|
|||||||
|
|
||||||
// Manager defines the interface for proxy operations
|
// Manager defines the interface for proxy operations
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error
|
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) error
|
||||||
Disconnect(ctx context.Context, proxyID string) error
|
Disconnect(ctx context.Context, proxyID string) error
|
||||||
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||||
GetActiveClusters(ctx context.Context) ([]Cluster, error)
|
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||||
|
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
|
||||||
|
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
|
||||||
|
IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||||
|
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// OIDCValidationConfig contains the OIDC configuration needed for token validation.
|
// OIDCValidationConfig contains the OIDC configuration needed for token validation.
|
||||||
|
|||||||
@@ -13,13 +13,19 @@ import (
|
|||||||
// store defines the interface for proxy persistence operations
|
// store defines the interface for proxy persistence operations
|
||||||
type store interface {
|
type store interface {
|
||||||
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
||||||
|
DisconnectProxy(ctx context.Context, proxyID string) error
|
||||||
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||||
|
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||||
|
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||||
|
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||||
|
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||||
|
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Manager handles all proxy operations
|
// Manager handles all proxy operations
|
||||||
@@ -43,7 +49,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
|||||||
|
|
||||||
// Connect registers a new proxy connection in the database.
|
// Connect registers a new proxy connection in the database.
|
||||||
// capabilities may be nil for old proxies that do not report them.
|
// capabilities may be nil for old proxies that do not report them.
|
||||||
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error {
|
func (m *Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string, capabilities *proxy.Capabilities) error {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
var caps proxy.Capabilities
|
var caps proxy.Capabilities
|
||||||
if capabilities != nil {
|
if capabilities != nil {
|
||||||
@@ -53,9 +59,10 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
|
|||||||
ID: proxyID,
|
ID: proxyID,
|
||||||
ClusterAddress: clusterAddress,
|
ClusterAddress: clusterAddress,
|
||||||
IPAddress: ipAddress,
|
IPAddress: ipAddress,
|
||||||
|
AccountID: accountID,
|
||||||
LastSeen: now,
|
LastSeen: now,
|
||||||
ConnectedAt: &now,
|
ConnectedAt: &now,
|
||||||
Status: "connected",
|
Status: proxy.StatusConnected,
|
||||||
Capabilities: caps,
|
Capabilities: caps,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,16 +81,8 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect marks a proxy as disconnected in the database
|
// Disconnect marks a proxy as disconnected in the database
|
||||||
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
func (m *Manager) Disconnect(ctx context.Context, proxyID string) error {
|
||||||
now := time.Now()
|
if err := m.store.DisconnectProxy(ctx, proxyID); err != nil {
|
||||||
p := &proxy.Proxy{
|
|
||||||
ID: proxyID,
|
|
||||||
Status: "disconnected",
|
|
||||||
DisconnectedAt: &now,
|
|
||||||
LastSeen: now,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
|
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -96,7 +95,7 @@ func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Heartbeat updates the proxy's last seen timestamp
|
// Heartbeat updates the proxy's last seen timestamp
|
||||||
func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
func (m *Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||||
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
|
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
|
||||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
|
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||||
return err
|
return err
|
||||||
@@ -108,7 +107,7 @@ func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddre
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
|
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
|
||||||
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||||
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
|
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||||
@@ -117,16 +116,6 @@ func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error
|
|||||||
return addresses, nil
|
return addresses, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
|
|
||||||
func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) {
|
|
||||||
clusters, err := m.store.GetActiveProxyClusters(ctx)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return clusters, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClusterSupportsCustomPorts returns whether any active proxy in the cluster
|
// ClusterSupportsCustomPorts returns whether any active proxy in the cluster
|
||||||
// supports custom ports. Returns nil when no proxy has reported capabilities.
|
// supports custom ports. Returns nil when no proxy has reported capabilities.
|
||||||
func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
||||||
@@ -146,10 +135,44 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||||
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
|
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||||
|
addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return addresses, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||||
|
return m.store.GetProxyByAccountID(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
|
||||||
|
return m.store.CountProxiesByAccountID(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||||
|
conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return !conflicting, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||||
|
if err := m.store.DeleteAccountCluster(ctx, clusterAddress, accountID); err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete cluster %s for account %s: %v", clusterAddress, accountID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,336 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.opentelemetry.io/otel/metric/noop"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockStore struct {
|
||||||
|
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
|
||||||
|
disconnectProxyFunc func(ctx context.Context, proxyID string) error
|
||||||
|
updateProxyHeartbeatFunc func(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||||
|
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
|
||||||
|
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
|
||||||
|
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
|
||||||
|
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||||
|
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
|
||||||
|
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||||
|
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||||
|
if m.saveProxyFunc != nil {
|
||||||
|
return m.saveProxyFunc(ctx, p)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID string) error {
|
||||||
|
if m.disconnectProxyFunc != nil {
|
||||||
|
return m.disconnectProxyFunc(ctx, proxyID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||||
|
if m.updateProxyHeartbeatFunc != nil {
|
||||||
|
return m.updateProxyHeartbeatFunc(ctx, proxyID, clusterAddress, ipAddress)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||||
|
if m.getActiveProxyClusterAddressesFunc != nil {
|
||||||
|
return m.getActiveProxyClusterAddressesFunc(ctx)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||||
|
if m.getActiveProxyClusterAddressesForAccFunc != nil {
|
||||||
|
return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
|
||||||
|
if m.cleanupStaleProxiesFunc != nil {
|
||||||
|
return m.cleanupStaleProxiesFunc(ctx, d)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||||
|
if m.getProxyByAccountIDFunc != nil {
|
||||||
|
return m.getProxyByAccountIDFunc(ctx, accountID)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("proxy not found for account %s", accountID)
|
||||||
|
}
|
||||||
|
func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
|
||||||
|
if m.countProxiesByAccountIDFunc != nil {
|
||||||
|
return m.countProxiesByAccountIDFunc(ctx, accountID)
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||||
|
if m.isClusterAddressConflictingFunc != nil {
|
||||||
|
return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID)
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
func (m *mockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||||
|
if m.deleteAccountClusterFunc != nil {
|
||||||
|
return m.deleteAccountClusterFunc(ctx, clusterAddress, accountID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *mockStore) GetClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *bool {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestManager(s store) *Manager {
|
||||||
|
meter := noop.NewMeterProvider().Meter("test")
|
||||||
|
m, err := NewManager(s, meter)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnect_WithAccountID(t *testing.T) {
|
||||||
|
accountID := "acc-123"
|
||||||
|
|
||||||
|
var savedProxy *proxy.Proxy
|
||||||
|
s := &mockStore{
|
||||||
|
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
|
||||||
|
savedProxy = p
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := newTestManager(s)
|
||||||
|
err := mgr.Connect(context.Background(), "proxy-1", "cluster.example.com", "10.0.0.1", &accountID, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NotNil(t, savedProxy)
|
||||||
|
assert.Equal(t, "proxy-1", savedProxy.ID)
|
||||||
|
assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress)
|
||||||
|
assert.Equal(t, "10.0.0.1", savedProxy.IPAddress)
|
||||||
|
assert.Equal(t, &accountID, savedProxy.AccountID)
|
||||||
|
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
|
||||||
|
assert.NotNil(t, savedProxy.ConnectedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnect_WithoutAccountID(t *testing.T) {
|
||||||
|
var savedProxy *proxy.Proxy
|
||||||
|
s := &mockStore{
|
||||||
|
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
|
||||||
|
savedProxy = p
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := newTestManager(s)
|
||||||
|
err := mgr.Connect(context.Background(), "proxy-1", "eu.proxy.netbird.io", "10.0.0.1", nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NotNil(t, savedProxy)
|
||||||
|
assert.Nil(t, savedProxy.AccountID)
|
||||||
|
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnect_StoreError(t *testing.T) {
|
||||||
|
s := &mockStore{
|
||||||
|
saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error {
|
||||||
|
return errors.New("db error")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := newTestManager(s)
|
||||||
|
err := mgr.Connect(context.Background(), "proxy-1", "cluster.example.com", "10.0.0.1", nil, nil)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsClusterAddressAvailable(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
conflicting bool
|
||||||
|
storeErr error
|
||||||
|
wantResult bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "available - no conflict",
|
||||||
|
conflicting: false,
|
||||||
|
wantResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not available - conflict exists",
|
||||||
|
conflicting: true,
|
||||||
|
wantResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "store error",
|
||||||
|
storeErr: errors.New("db error"),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
s := &mockStore{
|
||||||
|
isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) {
|
||||||
|
return tt.conflicting, tt.storeErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := newTestManager(s)
|
||||||
|
result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123")
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.wantResult, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCountAccountProxies(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
count int64
|
||||||
|
storeErr error
|
||||||
|
wantCount int64
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no proxies",
|
||||||
|
count: 0,
|
||||||
|
wantCount: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "one proxy",
|
||||||
|
count: 1,
|
||||||
|
wantCount: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "store error",
|
||||||
|
storeErr: errors.New("db error"),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
s := &mockStore{
|
||||||
|
countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) {
|
||||||
|
return tt.count, tt.storeErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := newTestManager(s)
|
||||||
|
count, err := mgr.CountAccountProxies(context.Background(), "acc-123")
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.wantCount, count)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAccountProxy(t *testing.T) {
|
||||||
|
accountID := "acc-123"
|
||||||
|
|
||||||
|
t.Run("found", func(t *testing.T) {
|
||||||
|
expected := &proxy.Proxy{
|
||||||
|
ID: "proxy-1",
|
||||||
|
ClusterAddress: "byop.example.com",
|
||||||
|
AccountID: &accountID,
|
||||||
|
Status: proxy.StatusConnected,
|
||||||
|
}
|
||||||
|
s := &mockStore{
|
||||||
|
getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) {
|
||||||
|
assert.Equal(t, accountID, accID)
|
||||||
|
return expected, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := newTestManager(s)
|
||||||
|
p, err := mgr.GetAccountProxy(context.Background(), accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, p)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("not found", func(t *testing.T) {
|
||||||
|
s := &mockStore{
|
||||||
|
getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) {
|
||||||
|
return nil, errors.New("not found")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := newTestManager(s)
|
||||||
|
_, err := mgr.GetAccountProxy(context.Background(), accountID)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteAccountCluster(t *testing.T) {
|
||||||
|
t.Run("success", func(t *testing.T) {
|
||||||
|
var deletedCluster, deletedAccount string
|
||||||
|
s := &mockStore{
|
||||||
|
deleteAccountClusterFunc: func(_ context.Context, clusterAddress, accountID string) error {
|
||||||
|
deletedCluster = clusterAddress
|
||||||
|
deletedAccount = accountID
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := newTestManager(s)
|
||||||
|
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "cluster.example.com", deletedCluster)
|
||||||
|
assert.Equal(t, "acc-123", deletedAccount)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("store error", func(t *testing.T) {
|
||||||
|
s := &mockStore{
|
||||||
|
deleteAccountClusterFunc: func(_ context.Context, _, _ string) error {
|
||||||
|
return errors.New("db error")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := newTestManager(s)
|
||||||
|
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetActiveClusterAddressesForAccount(t *testing.T) {
|
||||||
|
expected := []string{"byop.example.com"}
|
||||||
|
s := &mockStore{
|
||||||
|
getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) {
|
||||||
|
assert.Equal(t, "acc-123", accID)
|
||||||
|
return expected, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := newTestManager(s)
|
||||||
|
result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, result)
|
||||||
|
}
|
||||||
@@ -93,17 +93,17 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Connect mocks base method.
|
// Connect mocks base method.
|
||||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error {
|
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, accountID, capabilities)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect indicates an expected call of Connect.
|
// Connect indicates an expected call of Connect.
|
||||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, accountID, capabilities interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, accountID, capabilities)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect mocks base method.
|
// Disconnect mocks base method.
|
||||||
@@ -135,19 +135,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveClusters mocks base method.
|
func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||||
func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) {
|
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetActiveClusters", ctx)
|
ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID)
|
||||||
ret0, _ := ret[0].([]Cluster)
|
ret0, _ := ret[0].([]string)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveClusters indicates an expected call of GetActiveClusters.
|
func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
|
||||||
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Heartbeat mocks base method.
|
// Heartbeat mocks base method.
|
||||||
@@ -164,6 +162,65 @@ func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAdd
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAccountProxy mocks base method.
|
||||||
|
func (m *MockManager) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetAccountProxy", ctx, accountID)
|
||||||
|
ret0, _ := ret[0].(*Proxy)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountProxy indicates an expected call of GetAccountProxy.
|
||||||
|
func (mr *MockManagerMockRecorder) GetAccountProxy(ctx, accountID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountProxy", reflect.TypeOf((*MockManager)(nil).GetAccountProxy), ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountAccountProxies mocks base method.
|
||||||
|
func (m *MockManager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "CountAccountProxies", ctx, accountID)
|
||||||
|
ret0, _ := ret[0].(int64)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountAccountProxies indicates an expected call of CountAccountProxies.
|
||||||
|
func (mr *MockManagerMockRecorder) CountAccountProxies(ctx, accountID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountProxies", reflect.TypeOf((*MockManager)(nil).CountAccountProxies), ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClusterAddressAvailable mocks base method.
|
||||||
|
func (m *MockManager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "IsClusterAddressAvailable", ctx, clusterAddress, accountID)
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClusterAddressAvailable indicates an expected call of IsClusterAddressAvailable.
|
||||||
|
func (mr *MockManagerMockRecorder) IsClusterAddressAvailable(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressAvailable", reflect.TypeOf((*MockManager)(nil).IsClusterAddressAvailable), ctx, clusterAddress, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAccountCluster mocks base method.
|
||||||
|
func (m *MockManager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
|
||||||
|
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// MockController is a mock of Controller interface.
|
// MockController is a mock of Controller interface.
|
||||||
type MockController struct {
|
type MockController struct {
|
||||||
ctrl *gomock.Controller
|
ctrl *gomock.Controller
|
||||||
|
|||||||
@@ -1,6 +1,13 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
StatusConnected = "connected"
|
||||||
|
StatusDisconnected = "disconnected"
|
||||||
|
)
|
||||||
|
|
||||||
// Capabilities describes what a proxy can handle, as reported via gRPC.
|
// Capabilities describes what a proxy can handle, as reported via gRPC.
|
||||||
// Nil fields mean the proxy never reported this capability.
|
// Nil fields mean the proxy never reported this capability.
|
||||||
@@ -20,6 +27,7 @@ type Proxy struct {
|
|||||||
ID string `gorm:"primaryKey;type:varchar(255)"`
|
ID string `gorm:"primaryKey;type:varchar(255)"`
|
||||||
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
||||||
IPAddress string `gorm:"type:varchar(45)"`
|
IPAddress string `gorm:"type:varchar(45)"`
|
||||||
|
AccountID *string `gorm:"type:varchar(255);index:idx_proxy_account_id"`
|
||||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||||
ConnectedAt *time.Time
|
ConnectedAt *time.Time
|
||||||
DisconnectedAt *time.Time
|
DisconnectedAt *time.Time
|
||||||
@@ -35,6 +43,8 @@ func (Proxy) TableName() string {
|
|||||||
|
|
||||||
// Cluster represents a group of proxy nodes serving the same address.
|
// Cluster represents a group of proxy nodes serving the same address.
|
||||||
type Cluster struct {
|
type Cluster struct {
|
||||||
|
ID string
|
||||||
Address string
|
Address string
|
||||||
ConnectedProxies int
|
ConnectedProxies int
|
||||||
|
SelfHosted bool
|
||||||
}
|
}
|
||||||
|
|||||||
195
management/internals/modules/reverseproxy/proxytoken/handler.go
Normal file
195
management/internals/modules/reverseproxy/proxytoken/handler.go
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
package proxytoken
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
type handler struct {
|
||||||
|
store store.Store
|
||||||
|
permissionsManager permissions.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterEndpoints(s store.Store, permissionsManager permissions.Manager, router *mux.Router) {
|
||||||
|
h := &handler{store: s, permissionsManager: permissionsManager}
|
||||||
|
router.HandleFunc("/reverse-proxies/proxy-tokens", h.listTokens).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/reverse-proxies/proxy-tokens", h.createToken).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/reverse-proxies/proxy-tokens/{tokenId}", h.revokeToken).Methods("DELETE", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req api.ProxyTokenRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Name == "" || len(req.Name) > 255 {
|
||||||
|
util.WriteErrorResponse("name is required and must be at most 255 characters", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var expiresIn time.Duration
|
||||||
|
if req.ExpiresIn != nil {
|
||||||
|
if *req.ExpiresIn < 0 {
|
||||||
|
util.WriteErrorResponse("expires_in must be non-negative", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if *req.ExpiresIn > 0 {
|
||||||
|
expiresIn = time.Duration(*req.ExpiresIn) * time.Second
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID := userAuth.AccountId
|
||||||
|
generated, err := types.CreateNewProxyAccessToken(req.Name, expiresIn, &accountID, userAuth.UserId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteErrorResponse("failed to generate token", http.StatusInternalServerError, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil {
|
||||||
|
util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := toProxyTokenCreatedResponse(generated)
|
||||||
|
util.WriteJSONObject(r.Context(), w, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := make([]api.ProxyToken, 0, len(tokens))
|
||||||
|
for _, token := range tokens {
|
||||||
|
resp = append(resp, toProxyTokenResponse(token))
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenID := mux.Vars(r)["tokenId"]
|
||||||
|
if tokenID == "" {
|
||||||
|
util.WriteErrorResponse("token ID is required", http.StatusBadRequest, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID)
|
||||||
|
if err != nil {
|
||||||
|
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
|
||||||
|
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
|
||||||
|
} else {
|
||||||
|
util.WriteErrorResponse("failed to retrieve token", http.StatusInternalServerError, w)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if token.AccountID == nil || *token.AccountID != userAuth.AccountId {
|
||||||
|
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil {
|
||||||
|
util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken {
|
||||||
|
resp := api.ProxyToken{
|
||||||
|
Id: token.ID,
|
||||||
|
Name: token.Name,
|
||||||
|
Revoked: token.Revoked,
|
||||||
|
}
|
||||||
|
if !token.CreatedAt.IsZero() {
|
||||||
|
resp.CreatedAt = token.CreatedAt
|
||||||
|
}
|
||||||
|
if token.ExpiresAt != nil {
|
||||||
|
resp.ExpiresAt = token.ExpiresAt
|
||||||
|
}
|
||||||
|
if token.LastUsed != nil {
|
||||||
|
resp.LastUsed = token.LastUsed
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProxyTokenCreatedResponse(generated *types.ProxyAccessTokenGenerated) api.ProxyTokenCreated {
|
||||||
|
base := toProxyTokenResponse(&generated.ProxyAccessToken)
|
||||||
|
plainToken := string(generated.PlainToken)
|
||||||
|
return api.ProxyTokenCreated{
|
||||||
|
Id: base.Id,
|
||||||
|
Name: base.Name,
|
||||||
|
CreatedAt: base.CreatedAt,
|
||||||
|
ExpiresAt: base.ExpiresAt,
|
||||||
|
LastUsed: base.LastUsed,
|
||||||
|
Revoked: base.Revoked,
|
||||||
|
PlainToken: plainToken,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,275 @@
|
|||||||
|
package proxytoken
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/shared/auth"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func authContext(accountID, userID string) context.Context {
|
||||||
|
return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{
|
||||||
|
AccountId: accountID,
|
||||||
|
UserId: userID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateToken_AccountScoped(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
accountID := "acc-123"
|
||||||
|
var savedToken *types.ProxyAccessToken
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||||
|
func(_ context.Context, token *types.ProxyAccessToken) error {
|
||||||
|
savedToken = token
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
permsMgr := permissions.NewMockManager(ctrl)
|
||||||
|
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||||
|
|
||||||
|
h := &handler{
|
||||||
|
store: mockStore,
|
||||||
|
permissionsManager: permsMgr,
|
||||||
|
}
|
||||||
|
|
||||||
|
body := `{"name": "my-token"}`
|
||||||
|
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||||
|
req = req.WithContext(authContext(accountID, "user-1"))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.createToken(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var resp api.ProxyTokenCreated
|
||||||
|
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||||
|
|
||||||
|
assert.NotEmpty(t, resp.PlainToken)
|
||||||
|
assert.Equal(t, "my-token", resp.Name)
|
||||||
|
assert.False(t, resp.Revoked)
|
||||||
|
|
||||||
|
require.NotNil(t, savedToken)
|
||||||
|
require.NotNil(t, savedToken.AccountID)
|
||||||
|
assert.Equal(t, accountID, *savedToken.AccountID)
|
||||||
|
assert.Equal(t, "user-1", savedToken.CreatedBy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateToken_WithExpiration(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
var savedToken *types.ProxyAccessToken
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
|
||||||
|
func(_ context.Context, token *types.ProxyAccessToken) error {
|
||||||
|
savedToken = token
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
permsMgr := permissions.NewMockManager(ctrl)
|
||||||
|
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||||
|
|
||||||
|
h := &handler{
|
||||||
|
store: mockStore,
|
||||||
|
permissionsManager: permsMgr,
|
||||||
|
}
|
||||||
|
|
||||||
|
body := `{"name": "expiring-token", "expires_in": 3600}`
|
||||||
|
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||||
|
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.createToken(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
require.NotNil(t, savedToken)
|
||||||
|
require.NotNil(t, savedToken.ExpiresAt)
|
||||||
|
assert.True(t, savedToken.ExpiresAt.After(time.Now()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateToken_EmptyName(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
permsMgr := permissions.NewMockManager(ctrl)
|
||||||
|
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
|
||||||
|
|
||||||
|
h := &handler{
|
||||||
|
permissionsManager: permsMgr,
|
||||||
|
}
|
||||||
|
|
||||||
|
body := `{"name": ""}`
|
||||||
|
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||||
|
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.createToken(w, req)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateToken_PermissionDenied(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
permsMgr := permissions.NewMockManager(ctrl)
|
||||||
|
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil)
|
||||||
|
|
||||||
|
h := &handler{
|
||||||
|
permissionsManager: permsMgr,
|
||||||
|
}
|
||||||
|
|
||||||
|
body := `{"name": "test"}`
|
||||||
|
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
|
||||||
|
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.createToken(w, req)
|
||||||
|
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListTokens(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
accountID := "acc-123"
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().GetProxyAccessTokensByAccountID(gomock.Any(), store.LockingStrengthNone, accountID).Return([]*types.ProxyAccessToken{
|
||||||
|
{ID: "tok-1", Name: "token-1", AccountID: &accountID, CreatedAt: now, Revoked: false},
|
||||||
|
{ID: "tok-2", Name: "token-2", AccountID: &accountID, CreatedAt: now, Revoked: true},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
permsMgr := permissions.NewMockManager(ctrl)
|
||||||
|
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil)
|
||||||
|
|
||||||
|
h := &handler{
|
||||||
|
store: mockStore,
|
||||||
|
permissionsManager: permsMgr,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/reverse-proxies/proxy-tokens", nil)
|
||||||
|
req = req.WithContext(authContext(accountID, "user-1"))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.listTokens(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var resp []api.ProxyToken
|
||||||
|
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
|
||||||
|
require.Len(t, resp, 2)
|
||||||
|
assert.Equal(t, "tok-1", resp[0].Id)
|
||||||
|
assert.False(t, resp[0].Revoked)
|
||||||
|
assert.Equal(t, "tok-2", resp[1].Id)
|
||||||
|
assert.True(t, resp[1].Revoked)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRevokeToken_Success(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
accountID := "acc-123"
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||||
|
ID: "tok-1",
|
||||||
|
Name: "test-token",
|
||||||
|
AccountID: &accountID,
|
||||||
|
}, nil)
|
||||||
|
mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil)
|
||||||
|
|
||||||
|
permsMgr := permissions.NewMockManager(ctrl)
|
||||||
|
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||||
|
|
||||||
|
h := &handler{
|
||||||
|
store: mockStore,
|
||||||
|
permissionsManager: permsMgr,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||||
|
req = req.WithContext(authContext(accountID, "user-1"))
|
||||||
|
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.revokeToken(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRevokeToken_WrongAccount(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
otherAccount := "acc-other"
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||||
|
ID: "tok-1",
|
||||||
|
AccountID: &otherAccount,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
permsMgr := permissions.NewMockManager(ctrl)
|
||||||
|
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||||
|
|
||||||
|
h := &handler{
|
||||||
|
store: mockStore,
|
||||||
|
permissionsManager: permsMgr,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||||
|
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||||
|
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.revokeToken(w, req)
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRevokeToken_ManagementWideToken(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
mockStore := store.NewMockStore(ctrl)
|
||||||
|
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
|
||||||
|
ID: "tok-1",
|
||||||
|
AccountID: nil,
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
permsMgr := permissions.NewMockManager(ctrl)
|
||||||
|
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
|
||||||
|
|
||||||
|
h := &handler{
|
||||||
|
store: mockStore,
|
||||||
|
permissionsManager: permsMgr,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
|
||||||
|
req = req.WithContext(authContext("acc-123", "user-1"))
|
||||||
|
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.revokeToken(w, req)
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
}
|
||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
|
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
|
||||||
|
DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error
|
||||||
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
|
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
|
||||||
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
|
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
|
||||||
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
||||||
@@ -28,4 +29,5 @@ type Manager interface {
|
|||||||
RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
|
RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
|
||||||
StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
|
StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
|
||||||
StartExposeReaper(ctx context.Context)
|
StartExposeReaper(ctx context.Context)
|
||||||
|
GetServiceByDomain(ctx context.Context, domain string) (*Service, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID inte
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteAccountCluster mocks base method.
|
||||||
|
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, accountID, userID, clusterAddress)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
|
||||||
|
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID, clusterAddress interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteService mocks base method.
|
// DeleteService mocks base method.
|
||||||
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -138,6 +152,21 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetServiceByDomain mocks base method.
|
||||||
|
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||||
|
ret0, _ := ret[0].(*Service)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||||
|
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
|
||||||
|
}
|
||||||
|
|
||||||
// GetGlobalServices mocks base method.
|
// GetGlobalServices mocks base method.
|
||||||
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
|
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
|
|||||||
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
|
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
|
||||||
|
|
||||||
router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS")
|
router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/reverse-proxies/clusters/{clusterAddress}", h.deleteCluster).Methods("DELETE", "OPTIONS")
|
||||||
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
|
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
|
||||||
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
|
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
|
||||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
|
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
|
||||||
@@ -195,10 +196,33 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
|||||||
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
|
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
|
||||||
for _, c := range clusters {
|
for _, c := range clusters {
|
||||||
apiClusters = append(apiClusters, api.ProxyCluster{
|
apiClusters = append(apiClusters, api.ProxyCluster{
|
||||||
|
Id: c.ID,
|
||||||
Address: c.Address,
|
Address: c.Address,
|
||||||
ConnectedProxies: c.ConnectedProxies,
|
ConnectedProxies: c.ConnectedProxies,
|
||||||
|
SelfHosted: c.SelfHosted,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, apiClusters)
|
util.WriteJSONObject(r.Context(), w, apiClusters)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *handler) deleteCluster(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clusterAddress := mux.Vars(r)["clusterAddress"]
|
||||||
|
if clusterAddress == "" {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "cluster address is required"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.manager.DeleteAccountCluster(r.Context(), userAuth.AccountId, userAuth.UserId, clusterAddress); err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
|
}
|
||||||
|
|||||||
@@ -120,7 +120,21 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.store.GetActiveProxyClusters(ctx)
|
return m.store.GetActiveProxyClusters(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAccountCluster removes all proxy registrations for the given cluster address
|
||||||
|
// owned by the account.
|
||||||
|
func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
|
||||||
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||||
|
if err != nil {
|
||||||
|
return status.NewPermissionValidationError(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return status.NewPermissionDeniedError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.store.DeleteAccountCluster(ctx, clusterAddress, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||||
@@ -984,6 +998,10 @@ func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*
|
|||||||
return services, nil
|
return services, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||||
|
return m.store.GetServiceByDomain(ctx, domain)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||||
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -433,7 +433,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
|
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
|
||||||
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
|
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
|
||||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||||
return srv
|
return srv
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -712,7 +712,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
|||||||
|
|
||||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||||
|
|
||||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1135,7 +1135,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
|||||||
|
|
||||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
|
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||||
|
|
||||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -193,7 +193,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
|
|
||||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
|
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
|
||||||
s.AfterInit(func(s *BaseServer) {
|
s.AfterInit(func(s *BaseServer) {
|
||||||
proxyService.SetServiceManager(s.ServiceManager())
|
proxyService.SetServiceManager(s.ServiceManager())
|
||||||
proxyService.SetProxyController(s.ServiceProxyController())
|
proxyService.SetProxyController(s.ServiceProxyController())
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -47,6 +48,11 @@ type ProxyOIDCConfig struct {
|
|||||||
KeysLocation string
|
KeysLocation string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProxyTokenChecker checks whether a proxy access token is still valid.
|
||||||
|
type ProxyTokenChecker interface {
|
||||||
|
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
// ProxyServiceServer implements the ProxyService gRPC server
|
// ProxyServiceServer implements the ProxyService gRPC server
|
||||||
type ProxyServiceServer struct {
|
type ProxyServiceServer struct {
|
||||||
proto.UnimplementedProxyServiceServer
|
proto.UnimplementedProxyServiceServer
|
||||||
@@ -75,6 +81,9 @@ type ProxyServiceServer struct {
|
|||||||
// Store for one-time authentication tokens
|
// Store for one-time authentication tokens
|
||||||
tokenStore *OneTimeTokenStore
|
tokenStore *OneTimeTokenStore
|
||||||
|
|
||||||
|
// Checker for proxy access token validity
|
||||||
|
tokenChecker ProxyTokenChecker
|
||||||
|
|
||||||
// OIDC configuration for proxy authentication
|
// OIDC configuration for proxy authentication
|
||||||
oidcConfig ProxyOIDCConfig
|
oidcConfig ProxyOIDCConfig
|
||||||
|
|
||||||
@@ -90,6 +99,8 @@ const pkceVerifierTTL = 10 * time.Minute
|
|||||||
type proxyConnection struct {
|
type proxyConnection struct {
|
||||||
proxyID string
|
proxyID string
|
||||||
address string
|
address string
|
||||||
|
accountID *string
|
||||||
|
tokenID string
|
||||||
capabilities *proto.ProxyCapabilities
|
capabilities *proto.ProxyCapabilities
|
||||||
stream proto.ProxyService_GetMappingUpdateServer
|
stream proto.ProxyService_GetMappingUpdateServer
|
||||||
sendChan chan *proto.GetMappingUpdateResponse
|
sendChan chan *proto.GetMappingUpdateResponse
|
||||||
@@ -97,8 +108,19 @@ type proxyConnection struct {
|
|||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
||||||
|
token := GetProxyTokenFromContext(ctx)
|
||||||
|
if token == nil || token.AccountID == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if requestAccountID == "" || *token.AccountID != requestAccountID {
|
||||||
|
return status.Errorf(codes.PermissionDenied, "account-scoped token cannot access account %s", requestAccountID)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// NewProxyServiceServer creates a new proxy service server.
|
// NewProxyServiceServer creates a new proxy service server.
|
||||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
accessLogManager: accessLogMgr,
|
accessLogManager: accessLogMgr,
|
||||||
@@ -108,6 +130,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
|||||||
peersManager: peersManager,
|
peersManager: peersManager,
|
||||||
usersManager: usersManager,
|
usersManager: usersManager,
|
||||||
proxyManager: proxyMgr,
|
proxyManager: proxyMgr,
|
||||||
|
tokenChecker: tokenChecker,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
go s.cleanupStaleProxies(ctx)
|
go s.cleanupStaleProxies(ctx)
|
||||||
@@ -166,10 +189,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var accountID *string
|
||||||
|
token := GetProxyTokenFromContext(ctx)
|
||||||
|
if token != nil && token.AccountID != nil {
|
||||||
|
accountID = token.AccountID
|
||||||
|
|
||||||
|
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(codes.Internal, "check cluster address: %v", err)
|
||||||
|
}
|
||||||
|
if !available {
|
||||||
|
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenID string
|
||||||
|
if token != nil {
|
||||||
|
tokenID = token.ID
|
||||||
|
}
|
||||||
|
|
||||||
connCtx, cancel := context.WithCancel(ctx)
|
connCtx, cancel := context.WithCancel(ctx)
|
||||||
conn := &proxyConnection{
|
conn := &proxyConnection{
|
||||||
proxyID: proxyID,
|
proxyID: proxyID,
|
||||||
address: proxyAddress,
|
address: proxyAddress,
|
||||||
|
accountID: accountID,
|
||||||
|
tokenID: tokenID,
|
||||||
capabilities: req.GetCapabilities(),
|
capabilities: req.GetCapabilities(),
|
||||||
stream: stream,
|
stream: stream,
|
||||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||||
@@ -177,12 +221,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.connectedProxies.Store(proxyID, conn)
|
|
||||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
|
||||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register proxy in database with capabilities
|
|
||||||
var caps *proxy.Capabilities
|
var caps *proxy.Capabilities
|
||||||
if c := req.GetCapabilities(); c != nil {
|
if c := req.GetCapabilities(); c != nil {
|
||||||
caps = &proxy.Capabilities{
|
caps = &proxy.Capabilities{
|
||||||
@@ -191,19 +229,25 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil {
|
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, accountID, caps); err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
if accountID != nil {
|
||||||
s.connectedProxies.Delete(proxyID)
|
cancel()
|
||||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
|
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
||||||
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
|
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.connectedProxies.Store(proxyID, conn)
|
||||||
|
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"proxy_id": proxyID,
|
"proxy_id": proxyID,
|
||||||
"address": proxyAddress,
|
"address": proxyAddress,
|
||||||
"cluster_addr": proxyAddress,
|
"cluster_addr": proxyAddress,
|
||||||
|
"account_id": accountID,
|
||||||
"total_proxies": len(s.GetConnectedProxies()),
|
"total_proxies": len(s.GetConnectedProxies()),
|
||||||
}).Info("Proxy registered in cluster")
|
}).Info("Proxy registered in cluster")
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -228,7 +272,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
go s.sender(conn, errChan)
|
go s.sender(conn, errChan)
|
||||||
|
|
||||||
// Start heartbeat goroutine
|
// Start heartbeat goroutine
|
||||||
go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo)
|
go s.heartbeat(connCtx, conn, peerInfo)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
@@ -238,16 +282,28 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// heartbeat updates the proxy's last_seen timestamp every minute
|
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, ipAddress string) {
|
||||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) {
|
|
||||||
ticker := time.NewTicker(1 * time.Minute)
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
|
if err := s.proxyManager.Heartbeat(ctx, conn.proxyID, conn.address, ipAddress); err != nil {
|
||||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
|
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", conn.proxyID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if conn.tokenID != "" && s.tokenChecker != nil {
|
||||||
|
valid, err := s.tokenChecker.IsProxyAccessTokenValid(ctx, conn.tokenID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to check token validity for proxy %s: %v", conn.proxyID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
log.WithContext(ctx).Warnf("proxy %s token revoked or expired, disconnecting", conn.proxyID)
|
||||||
|
conn.cancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
@@ -255,8 +311,6 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
|
|
||||||
// Only entries matching the proxy's cluster address are sent.
|
|
||||||
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
|
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
|
||||||
if !isProxyAddressValid(conn.address) {
|
if !isProxyAddressValid(conn.address) {
|
||||||
return fmt.Errorf("proxy address is invalid")
|
return fmt.Errorf("proxy address is invalid")
|
||||||
@@ -289,7 +343,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) {
|
func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) {
|
||||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
var services []*rpservice.Service
|
||||||
|
var err error
|
||||||
|
if conn.accountID != nil {
|
||||||
|
services, err = s.serviceManager.GetAccountServices(ctx, *conn.accountID)
|
||||||
|
} else {
|
||||||
|
services, err = s.serviceManager.GetGlobalServices(ctx)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get services from store: %w", err)
|
return nil, fmt.Errorf("get services from store: %w", err)
|
||||||
}
|
}
|
||||||
@@ -318,8 +378,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
|||||||
return mappings, nil
|
return mappings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// isProxyAddressValid validates a proxy address
|
// isProxyAddressValid validates a proxy address (domain name or IP address)
|
||||||
func isProxyAddressValid(addr string) bool {
|
func isProxyAddressValid(addr string) bool {
|
||||||
|
if addr == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if net.ParseIP(addr) != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
_, err := domain.ValidateDomains([]string{addr})
|
_, err := domain.ValidateDomains([]string{addr})
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
@@ -343,6 +409,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
|||||||
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
|
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
|
||||||
accessLog := req.GetLog()
|
accessLog := req.GetLog()
|
||||||
|
|
||||||
|
if err := enforceAccountScope(ctx, accessLog.GetAccountId()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
fields := log.Fields{
|
fields := log.Fields{
|
||||||
"service_id": accessLog.GetServiceId(),
|
"service_id": accessLog.GetServiceId(),
|
||||||
"account_id": accessLog.GetAccountId(),
|
"account_id": accessLog.GetAccountId(),
|
||||||
@@ -380,11 +450,32 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
|
|||||||
// Management should call this when services are created/updated/removed.
|
// Management should call this when services are created/updated/removed.
|
||||||
// For create/update operations a unique one-time auth token is generated per
|
// For create/update operations a unique one-time auth token is generated per
|
||||||
// proxy so that every replica can independently authenticate with management.
|
// proxy so that every replica can independently authenticate with management.
|
||||||
|
// BYOP proxies only receive updates for their own account's services.
|
||||||
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
|
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
|
||||||
log.Debugf("Broadcasting service update to all connected proxy servers")
|
log.Debugf("Broadcasting service update to all connected proxy servers")
|
||||||
|
updateAccountIDs := make(map[string]struct{})
|
||||||
|
for _, m := range update.Mapping {
|
||||||
|
if m.AccountId != "" {
|
||||||
|
updateAccountIDs[m.AccountId] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
s.connectedProxies.Range(func(key, value interface{}) bool {
|
s.connectedProxies.Range(func(key, value interface{}) bool {
|
||||||
conn := value.(*proxyConnection)
|
conn := value.(*proxyConnection)
|
||||||
resp := s.perProxyMessage(update, conn.proxyID)
|
connUpdate := update
|
||||||
|
if conn.accountID != nil && len(updateAccountIDs) > 0 {
|
||||||
|
if _, ok := updateAccountIDs[*conn.accountID]; !ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
filtered := filterMappingsForAccount(update.Mapping, *conn.accountID)
|
||||||
|
if len(filtered) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
connUpdate = &proto.GetMappingUpdateResponse{
|
||||||
|
Mapping: filtered,
|
||||||
|
InitialSyncComplete: update.InitialSyncComplete,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp := s.perProxyMessage(connUpdate, conn.proxyID)
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -398,6 +489,26 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ForceDisconnect cancels the gRPC stream for a connected proxy, causing it to disconnect.
|
||||||
|
func (s *ProxyServiceServer) ForceDisconnect(proxyID string) {
|
||||||
|
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
||||||
|
conn := connVal.(*proxyConnection)
|
||||||
|
conn.cancel()
|
||||||
|
s.connectedProxies.Delete(proxyID)
|
||||||
|
log.WithFields(log.Fields{"proxyID": proxyID}).Info("force disconnected proxy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterMappingsForAccount(mappings []*proto.ProxyMapping, accountID string) []*proto.ProxyMapping {
|
||||||
|
var filtered []*proto.ProxyMapping
|
||||||
|
for _, m := range mappings {
|
||||||
|
if m.AccountId == accountID {
|
||||||
|
filtered = append(filtered, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
// GetConnectedProxies returns a list of connected proxy IDs
|
// GetConnectedProxies returns a list of connected proxy IDs
|
||||||
func (s *ProxyServiceServer) GetConnectedProxies() []string {
|
func (s *ProxyServiceServer) GetConnectedProxies() []string {
|
||||||
var proxies []string
|
var proxies []string
|
||||||
@@ -466,6 +577,9 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
conn := connVal.(*proxyConnection)
|
conn := connVal.(*proxyConnection)
|
||||||
|
if conn.accountID != nil && update.AccountId != "" && *conn.accountID != update.AccountId {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if !proxyAcceptsMapping(conn, update) {
|
if !proxyAcceptsMapping(conn, update) {
|
||||||
log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id)
|
log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id)
|
||||||
continue
|
continue
|
||||||
@@ -549,6 +663,10 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||||
|
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
|
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
|
||||||
@@ -668,6 +786,10 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
|
|||||||
|
|
||||||
// SendStatusUpdate handles status updates from proxy clients.
|
// SendStatusUpdate handles status updates from proxy clients.
|
||||||
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
|
||||||
|
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
accountID := req.GetAccountId()
|
accountID := req.GetAccountId()
|
||||||
serviceID := req.GetServiceId()
|
serviceID := req.GetServiceId()
|
||||||
protoStatus := req.GetStatus()
|
protoStatus := req.GetStatus()
|
||||||
@@ -738,6 +860,10 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
|
|||||||
|
|
||||||
// CreateProxyPeer handles proxy peer creation with one-time token authentication
|
// CreateProxyPeer handles proxy peer creation with one-time token authentication
|
||||||
func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) {
|
func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) {
|
||||||
|
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
serviceID := req.GetServiceId()
|
serviceID := req.GetServiceId()
|
||||||
accountID := req.GetAccountId()
|
accountID := req.GetAccountId()
|
||||||
token := req.GetToken()
|
token := req.GetToken()
|
||||||
@@ -792,6 +918,10 @@ func strPtr(s string) *string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) {
|
func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) {
|
||||||
|
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
redirectURL, err := url.Parse(req.GetRedirectUrl())
|
redirectURL, err := url.Parse(req.GetRedirectUrl())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
|
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
|
||||||
@@ -920,21 +1050,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
|
|||||||
|
|
||||||
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
// GenerateSessionToken creates a signed session JWT for the given domain and user.
|
||||||
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
|
||||||
// Find the service by domain to get its signing key
|
service, err := s.getServiceByDomain(ctx, domain)
|
||||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("get services: %w", err)
|
return "", fmt.Errorf("service not found for domain %s: %w", domain, err)
|
||||||
}
|
|
||||||
|
|
||||||
var service *rpservice.Service
|
|
||||||
for _, svc := range services {
|
|
||||||
if svc.Domain == domain {
|
|
||||||
service = svc
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if service == nil {
|
|
||||||
return "", fmt.Errorf("service not found for domain: %s", domain)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if service.SessionPrivateKey == "" {
|
if service.SessionPrivateKey == "" {
|
||||||
@@ -1032,6 +1150,10 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey)
|
pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
@@ -1115,18 +1237,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
|
||||||
services, err := s.serviceManager.GetGlobalServices(ctx)
|
return s.serviceManager.GetServiceByDomain(ctx, domain)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get services: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, service := range services {
|
|
||||||
if service.Domain == domain {
|
|
||||||
return service, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("service not found for domain: %s", domain)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
|
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
|
||||||
|
|||||||
29
management/internals/shared/grpc/proxy_address_test.go
Normal file
29
management/internals/shared/grpc/proxy_address_test.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsProxyAddressValid(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addr string
|
||||||
|
valid bool
|
||||||
|
}{
|
||||||
|
{name: "valid domain", addr: "eu.proxy.netbird.io", valid: true},
|
||||||
|
{name: "valid subdomain", addr: "byop.proxy.example.com", valid: true},
|
||||||
|
{name: "valid IPv4", addr: "10.0.0.1", valid: true},
|
||||||
|
{name: "valid IPv4 public", addr: "203.0.113.10", valid: true},
|
||||||
|
{name: "valid IPv6", addr: "::1", valid: true},
|
||||||
|
{name: "valid IPv6 full", addr: "2001:db8::1", valid: true},
|
||||||
|
{name: "empty string", addr: "", valid: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.valid, isProxyAddressValid(tt.addr))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -153,9 +153,6 @@ func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types
|
|||||||
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
|
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
|
|
||||||
// Currently tokens are management-wide; AccountID field is reserved for future use.
|
|
||||||
|
|
||||||
if !token.IsValid() {
|
if !token.IsValid() {
|
||||||
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
|
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,6 +53,10 @@ func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
|
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -91,6 +95,20 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _
|
|||||||
|
|
||||||
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
|
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain string) (*service.Service, error) {
|
||||||
|
if m.err != nil {
|
||||||
|
return nil, m.err
|
||||||
|
}
|
||||||
|
for _, services := range m.proxiesByAccount {
|
||||||
|
for _, svc := range services {
|
||||||
|
if svc.Domain == domain {
|
||||||
|
return svc, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, errors.New("service not found for domain: " + domain)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,9 +12,12 @@ import (
|
|||||||
cachestore "github.com/eko/gocache/lib/v4/store"
|
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
grpcstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -313,6 +316,58 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
|
|||||||
assert.Contains(t, err.Error(), "invalid state format")
|
assert.Contains(t, err.Error(), "invalid state format")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func scopedCtx(accountID string) context.Context {
|
||||||
|
token := &types.ProxyAccessToken{
|
||||||
|
ID: "token-1",
|
||||||
|
AccountID: &accountID,
|
||||||
|
}
|
||||||
|
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func globalCtx() context.Context {
|
||||||
|
token := &types.ProxyAccessToken{
|
||||||
|
ID: "token-global",
|
||||||
|
}
|
||||||
|
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnforceAccountScope_AllowsMatchingAccount(t *testing.T) {
|
||||||
|
err := enforceAccountScope(scopedCtx("acc-1"), "acc-1")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnforceAccountScope_BlocksMismatchedAccount(t *testing.T) {
|
||||||
|
err := enforceAccountScope(scopedCtx("acc-1"), "acc-2")
|
||||||
|
require.Error(t, err)
|
||||||
|
st, ok := grpcstatus.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, codes.PermissionDenied, st.Code())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnforceAccountScope_BlocksEmptyRequestAccountID(t *testing.T) {
|
||||||
|
err := enforceAccountScope(scopedCtx("acc-1"), "")
|
||||||
|
require.Error(t, err)
|
||||||
|
st, ok := grpcstatus.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, codes.PermissionDenied, st.Code())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnforceAccountScope_AllowsGlobalToken(t *testing.T) {
|
||||||
|
err := enforceAccountScope(globalCtx(), "acc-1")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = enforceAccountScope(globalCtx(), "acc-2")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = enforceAccountScope(globalCtx(), "")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnforceAccountScope_AllowsNoTokenInContext(t *testing.T) {
|
||||||
|
err := enforceAccountScope(context.Background(), "acc-1")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
|
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
|||||||
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
|
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||||
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
|
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||||
|
|
||||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
|
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
|
||||||
proxyService.SetServiceManager(serviceManager)
|
proxyService.SetServiceManager(serviceManager)
|
||||||
|
|
||||||
createTestProxies(t, ctx, testStore)
|
createTestProxies(t, ctx, testStore)
|
||||||
@@ -318,13 +318,17 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex
|
|||||||
|
|
||||||
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
|
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
|
||||||
|
|
||||||
|
func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||||
|
return m.store.GetServiceByDomain(ctx, domain)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type testValidateSessionProxyManager struct{}
|
type testValidateSessionProxyManager struct{}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *proxy.Capabilities) error {
|
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *proxy.Capabilities) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -340,6 +344,10 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) {
|
func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -348,6 +356,22 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) GetAccountProxy(_ context.Context, _ string) (*proxy.Proxy, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testValidateSessionProxyManager) DeleteProxy(_ context.Context, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
|
func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -408,7 +408,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
networkMap := account.GetPeerNetworkMapFromComponents(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||||
}
|
}
|
||||||
@@ -1171,11 +1171,6 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
|||||||
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
|
|
||||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
|
||||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||||
}
|
}
|
||||||
@@ -1231,11 +1226,6 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
|
|
||||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
|
||||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||||
}
|
}
|
||||||
@@ -1274,11 +1264,6 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
|
|
||||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
|
||||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||||
}
|
}
|
||||||
@@ -1332,11 +1317,6 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
|
|
||||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
|
||||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||||
}
|
}
|
||||||
@@ -1397,11 +1377,6 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
|
|
||||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
|
||||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||||
}
|
}
|
||||||
@@ -1633,75 +1608,6 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
|
|||||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_GetRoutesToSync(t *testing.T) {
|
|
||||||
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
_, prefix2, err := route.ParseNetwork("192.168.0.0/24")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
account := &types.Account{
|
|
||||||
Peers: map[string]*nbpeer.Peer{
|
|
||||||
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
|
|
||||||
},
|
|
||||||
Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
|
|
||||||
Routes: map[route.ID]*route.Route{
|
|
||||||
"route-1": {
|
|
||||||
ID: "route-1",
|
|
||||||
Network: prefix,
|
|
||||||
NetID: "network-1",
|
|
||||||
Description: "network-1",
|
|
||||||
Peer: "peer-1",
|
|
||||||
NetworkType: 0,
|
|
||||||
Masquerade: false,
|
|
||||||
Metric: 999,
|
|
||||||
Enabled: true,
|
|
||||||
Groups: []string{"group1"},
|
|
||||||
},
|
|
||||||
"route-2": {
|
|
||||||
ID: "route-2",
|
|
||||||
Network: prefix2,
|
|
||||||
NetID: "network-2",
|
|
||||||
Description: "network-2",
|
|
||||||
Peer: "peer-2",
|
|
||||||
NetworkType: 0,
|
|
||||||
Masquerade: false,
|
|
||||||
Metric: 999,
|
|
||||||
Enabled: true,
|
|
||||||
Groups: []string{"group1"},
|
|
||||||
},
|
|
||||||
"route-3": {
|
|
||||||
ID: "route-3",
|
|
||||||
Network: prefix,
|
|
||||||
NetID: "network-1",
|
|
||||||
Description: "network-1",
|
|
||||||
Peer: "peer-2",
|
|
||||||
NetworkType: 0,
|
|
||||||
Masquerade: false,
|
|
||||||
Metric: 999,
|
|
||||||
Enabled: true,
|
|
||||||
Groups: []string{"group1"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
|
|
||||||
|
|
||||||
assert.Len(t, routes, 2)
|
|
||||||
routeIDs := make(map[route.ID]struct{}, 2)
|
|
||||||
for _, r := range routes {
|
|
||||||
routeIDs[r.ID] = struct{}{}
|
|
||||||
}
|
|
||||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
|
||||||
assert.Contains(t, routeIDs, route.ID("route-3"))
|
|
||||||
|
|
||||||
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
|
|
||||||
|
|
||||||
assert.Len(t, emptyRoutes, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccount_Copy(t *testing.T) {
|
func TestAccount_Copy(t *testing.T) {
|
||||||
account := &types.Account{
|
account := &types.Account{
|
||||||
Id: "account1",
|
Id: "account1",
|
||||||
@@ -1824,9 +1730,7 @@ func TestAccount_Copy(t *testing.T) {
|
|||||||
AccountID: "account1",
|
AccountID: "account1",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
|
||||||
}
|
}
|
||||||
account.InitOnce()
|
|
||||||
err := hasNilField(account)
|
err := hasNilField(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -3170,7 +3074,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
|
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
|
||||||
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
|
||||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||||
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
@@ -146,6 +147,9 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
|||||||
if serviceManager != nil && reverseProxyDomainManager != nil {
|
if serviceManager != nil && reverseProxyDomainManager != nil {
|
||||||
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
|
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
proxytoken.RegisterEndpoints(accountManager.GetStore(), permissionsManager, router)
|
||||||
|
|
||||||
// Register OAuth callback handler for proxy authentication
|
// Register OAuth callback handler for proxy authentication
|
||||||
if proxyGRPCServer != nil {
|
if proxyGRPCServer != nil {
|
||||||
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies)
|
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies)
|
||||||
|
|||||||
@@ -417,7 +417,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
|
||||||
|
|
||||||
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
netMap := account.GetPeerNetworkMapFromComponents(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -216,6 +216,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
|||||||
nil,
|
nil,
|
||||||
usersManager,
|
usersManager,
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
proxyService.SetServiceManager(&testServiceManager{store: testStore})
|
proxyService.SetServiceManager(&testServiceManager{store: testStore})
|
||||||
@@ -389,6 +390,10 @@ func (m *testServiceManager) DeleteService(_ context.Context, _, _, _ string) er
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
|
func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -435,6 +440,10 @@ func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ stri
|
|||||||
|
|
||||||
func (m *testServiceManager) StartExposeReaper(_ context.Context) {}
|
func (m *testServiceManager) StartExposeReaper(_ context.Context) {}
|
||||||
|
|
||||||
|
func (m *testServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||||
|
return m.store.GetServiceByDomain(ctx, domain)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||||
}
|
}
|
||||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
|
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
||||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -238,7 +238,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||||
}
|
}
|
||||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
|
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
||||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -179,11 +179,6 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
|||||||
testGetNetworkMapGeneral(t)
|
testGetNetworkMapGeneral(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
|
|
||||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
|
||||||
testGetNetworkMapGeneral(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testGetNetworkMapGeneral(t *testing.T) {
|
func testGetNetworkMapGeneral(t *testing.T) {
|
||||||
manager, _, err := createManager(t)
|
manager, _, err := createManager(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1016,11 +1011,6 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateAccountPeers_Experimental(t *testing.T) {
|
|
||||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
|
||||||
testUpdateAccountPeers(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateAccountPeers(t *testing.T) {
|
func TestUpdateAccountPeers(t *testing.T) {
|
||||||
testUpdateAccountPeers(t)
|
testUpdateAccountPeers(t)
|
||||||
}
|
}
|
||||||
@@ -1600,7 +1590,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Test_LoginPeer(t *testing.T) {
|
func Test_LoginPeer(t *testing.T) {
|
||||||
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
|
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,10 +2,8 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sort"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -1840,11 +1838,6 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
validatedPeers := make(map[string]struct{})
|
|
||||||
for p := range account.Peers {
|
|
||||||
validatedPeers[p] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("check applied policies for the route", func(t *testing.T) {
|
t.Run("check applied policies for the route", func(t *testing.T) {
|
||||||
route1 := account.Routes["route1"]
|
route1 := account.Routes["route1"]
|
||||||
policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups)
|
policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups)
|
||||||
@@ -1858,116 +1851,6 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
|
|||||||
policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups)
|
policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups)
|
||||||
assert.Len(t, policies, 0)
|
assert.Len(t, policies, 0)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("check peer routes firewall rules", func(t *testing.T) {
|
|
||||||
routesFirewallRules := account.GetPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
|
|
||||||
assert.Len(t, routesFirewallRules, 4)
|
|
||||||
|
|
||||||
expectedRoutesFirewallRules := []*types.RouteFirewallRule{
|
|
||||||
{
|
|
||||||
SourceRanges: []string{
|
|
||||||
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
|
|
||||||
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
|
|
||||||
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
|
|
||||||
},
|
|
||||||
Action: "accept",
|
|
||||||
Destination: "192.168.0.0/16",
|
|
||||||
Protocol: "all",
|
|
||||||
Port: 80,
|
|
||||||
RouteID: "route1:peerA",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
SourceRanges: []string{
|
|
||||||
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
|
|
||||||
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
|
|
||||||
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
|
|
||||||
},
|
|
||||||
Action: "accept",
|
|
||||||
Destination: "192.168.0.0/16",
|
|
||||||
Protocol: "all",
|
|
||||||
Port: 320,
|
|
||||||
RouteID: "route1:peerA",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
additionalFirewallRule := []*types.RouteFirewallRule{
|
|
||||||
{
|
|
||||||
SourceRanges: []string{
|
|
||||||
fmt.Sprintf(types.AllowedIPsFormat, peerJIp),
|
|
||||||
},
|
|
||||||
Action: "accept",
|
|
||||||
Destination: "192.168.10.0/16",
|
|
||||||
Protocol: "tcp",
|
|
||||||
Port: 80,
|
|
||||||
RouteID: "route4:peerA",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
SourceRanges: []string{
|
|
||||||
fmt.Sprintf(types.AllowedIPsFormat, peerKIp),
|
|
||||||
},
|
|
||||||
Action: "accept",
|
|
||||||
Destination: "192.168.10.0/16",
|
|
||||||
Protocol: "all",
|
|
||||||
RouteID: "route4:peerA",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...)))
|
|
||||||
|
|
||||||
// peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
|
|
||||||
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
|
|
||||||
assert.Len(t, routesFirewallRules, 2)
|
|
||||||
for _, rule := range expectedRoutesFirewallRules {
|
|
||||||
rule.RouteID = "route1:peerD"
|
|
||||||
}
|
|
||||||
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
|
|
||||||
|
|
||||||
// peerE is a single routing peer for route 2 and route 3
|
|
||||||
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
|
|
||||||
assert.Len(t, routesFirewallRules, 3)
|
|
||||||
|
|
||||||
expectedRoutesFirewallRules = []*types.RouteFirewallRule{
|
|
||||||
{
|
|
||||||
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
|
|
||||||
Action: "accept",
|
|
||||||
Destination: existingNetwork.String(),
|
|
||||||
Protocol: "tcp",
|
|
||||||
PortRange: types.RulePortRange{Start: 80, End: 350},
|
|
||||||
RouteID: "route2",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
SourceRanges: []string{"0.0.0.0/0"},
|
|
||||||
Action: "accept",
|
|
||||||
Destination: "192.0.2.0/32",
|
|
||||||
Protocol: "all",
|
|
||||||
Domains: domain.List{"example.com"},
|
|
||||||
IsDynamic: true,
|
|
||||||
RouteID: "route3",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
SourceRanges: []string{"::/0"},
|
|
||||||
Action: "accept",
|
|
||||||
Destination: "192.0.2.0/32",
|
|
||||||
Protocol: "all",
|
|
||||||
Domains: domain.List{"example.com"},
|
|
||||||
IsDynamic: true,
|
|
||||||
RouteID: "route3",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
|
|
||||||
|
|
||||||
// peerC is part of route1 distribution groups but should not receive the routes firewall rules
|
|
||||||
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
|
|
||||||
assert.Len(t, routesFirewallRules, 0)
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// orderList is a helper function to sort a list of strings
|
|
||||||
func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFirewallRule {
|
|
||||||
for _, rule := range ruleList {
|
|
||||||
sort.Strings(rule.SourceRanges)
|
|
||||||
}
|
|
||||||
return ruleList
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRouteAccountPeersUpdate(t *testing.T) {
|
func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||||
@@ -2665,11 +2548,6 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
validatedPeers := make(map[string]struct{})
|
|
||||||
for p := range account.Peers {
|
|
||||||
validatedPeers[p] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("validate applied policies for different network resources", func(t *testing.T) {
|
t.Run("validate applied policies for different network resources", func(t *testing.T) {
|
||||||
// Test case: Resource1 is directly applied to the policy (policyResource1)
|
// Test case: Resource1 is directly applied to the policy (policyResource1)
|
||||||
policies := account.GetPoliciesForNetworkResource("resource1")
|
policies := account.GetPoliciesForNetworkResource("resource1")
|
||||||
@@ -2693,127 +2571,4 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) {
|
|||||||
policies = account.GetPoliciesForNetworkResource("resource6")
|
policies = account.GetPoliciesForNetworkResource("resource6")
|
||||||
assert.Len(t, policies, 1, "resource6 should have exactly 1 policy applied via access control groups")
|
assert.Len(t, policies, 1, "resource6 should have exactly 1 policy applied via access control groups")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("validate routing peer firewall rules for network resources", func(t *testing.T) {
|
|
||||||
resourcePoliciesMap := account.GetResourcePoliciesMap()
|
|
||||||
resourceRoutersMap := account.GetResourceRoutersMap()
|
|
||||||
_, routes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), "peerA", resourcePoliciesMap, resourceRoutersMap)
|
|
||||||
firewallRules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerA"], validatedPeers, routes, resourcePoliciesMap)
|
|
||||||
assert.Len(t, firewallRules, 4)
|
|
||||||
assert.Len(t, sourcePeers, 5)
|
|
||||||
|
|
||||||
expectedFirewallRules := []*types.RouteFirewallRule{
|
|
||||||
{
|
|
||||||
SourceRanges: []string{
|
|
||||||
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
|
|
||||||
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
|
|
||||||
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
|
|
||||||
},
|
|
||||||
Action: "accept",
|
|
||||||
Destination: "192.168.0.0/16",
|
|
||||||
Protocol: "all",
|
|
||||||
Port: 80,
|
|
||||||
RouteID: "resource2:peerA",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
SourceRanges: []string{
|
|
||||||
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
|
|
||||||
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
|
|
||||||
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
|
|
||||||
},
|
|
||||||
Action: "accept",
|
|
||||||
Destination: "192.168.0.0/16",
|
|
||||||
Protocol: "all",
|
|
||||||
Port: 320,
|
|
||||||
RouteID: "resource2:peerA",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
additionalFirewallRules := []*types.RouteFirewallRule{
|
|
||||||
{
|
|
||||||
SourceRanges: []string{
|
|
||||||
fmt.Sprintf(types.AllowedIPsFormat, peerJIp),
|
|
||||||
},
|
|
||||||
Action: "accept",
|
|
||||||
Destination: "192.0.2.0/32",
|
|
||||||
Protocol: "tcp",
|
|
||||||
Port: 80,
|
|
||||||
Domains: domain.List{"example.com"},
|
|
||||||
IsDynamic: true,
|
|
||||||
RouteID: "resource4:peerA",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
SourceRanges: []string{
|
|
||||||
fmt.Sprintf(types.AllowedIPsFormat, peerKIp),
|
|
||||||
},
|
|
||||||
Action: "accept",
|
|
||||||
Destination: "192.0.2.0/32",
|
|
||||||
Protocol: "all",
|
|
||||||
Domains: domain.List{"example.com"},
|
|
||||||
IsDynamic: true,
|
|
||||||
RouteID: "resource4:peerA",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(append(expectedFirewallRules, additionalFirewallRules...)))
|
|
||||||
|
|
||||||
// peerD is also the routing peer for resource2
|
|
||||||
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerD", resourcePoliciesMap, resourceRoutersMap)
|
|
||||||
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerD"], validatedPeers, routes, resourcePoliciesMap)
|
|
||||||
assert.Len(t, firewallRules, 2)
|
|
||||||
for _, rule := range expectedFirewallRules {
|
|
||||||
rule.RouteID = "resource2:peerD"
|
|
||||||
}
|
|
||||||
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
|
|
||||||
assert.Len(t, sourcePeers, 3)
|
|
||||||
|
|
||||||
// peerE is a single routing peer for resource1 and resource3
|
|
||||||
// PeerE should only receive rules for resource1 since resource3 has no applied policy
|
|
||||||
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerE", resourcePoliciesMap, resourceRoutersMap)
|
|
||||||
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerE"], validatedPeers, routes, resourcePoliciesMap)
|
|
||||||
assert.Len(t, firewallRules, 1)
|
|
||||||
assert.Len(t, sourcePeers, 2)
|
|
||||||
|
|
||||||
expectedFirewallRules = []*types.RouteFirewallRule{
|
|
||||||
{
|
|
||||||
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
|
|
||||||
Action: "accept",
|
|
||||||
Destination: "10.10.10.0/24",
|
|
||||||
Protocol: "tcp",
|
|
||||||
PortRange: types.RulePortRange{Start: 80, End: 350},
|
|
||||||
RouteID: "resource1:peerE",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
|
|
||||||
|
|
||||||
// peerC is part of distribution groups for resource2 but should not receive the firewall rules
|
|
||||||
firewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
|
|
||||||
assert.Len(t, firewallRules, 0)
|
|
||||||
|
|
||||||
// peerL is the single routing peer for resource5
|
|
||||||
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerL", resourcePoliciesMap, resourceRoutersMap)
|
|
||||||
assert.Len(t, routes, 1)
|
|
||||||
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerL"], validatedPeers, routes, resourcePoliciesMap)
|
|
||||||
assert.Len(t, firewallRules, 1)
|
|
||||||
assert.Len(t, sourcePeers, 1)
|
|
||||||
|
|
||||||
expectedFirewallRules = []*types.RouteFirewallRule{
|
|
||||||
{
|
|
||||||
SourceRanges: []string{"100.65.29.67/32"},
|
|
||||||
Action: "accept",
|
|
||||||
Destination: "10.12.12.1/32",
|
|
||||||
Protocol: "tcp",
|
|
||||||
Port: 8080,
|
|
||||||
RouteID: "resource5:peerL",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
|
|
||||||
|
|
||||||
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerM", resourcePoliciesMap, resourceRoutersMap)
|
|
||||||
assert.Len(t, routes, 1)
|
|
||||||
assert.Len(t, sourcePeers, 0)
|
|
||||||
|
|
||||||
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerN", resourcePoliciesMap, resourceRoutersMap)
|
|
||||||
assert.Len(t, routes, 1)
|
|
||||||
assert.Len(t, sourcePeers, 2)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1196,7 +1196,6 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
|||||||
account.NameServerGroups[ns.ID] = &ns
|
account.NameServerGroups[ns.ID] = &ns
|
||||||
}
|
}
|
||||||
account.NameServerGroupsG = nil
|
account.NameServerGroupsG = nil
|
||||||
account.InitOnce()
|
|
||||||
return &account, nil
|
return &account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1635,7 +1634,6 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc
|
|||||||
if sExtraIntegratedValidatorGroups.Valid {
|
if sExtraIntegratedValidatorGroups.Valid {
|
||||||
_ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups)
|
_ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups)
|
||||||
}
|
}
|
||||||
account.InitOnce()
|
|
||||||
return &account, nil
|
return &account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -4497,6 +4495,47 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error) {
|
||||||
|
tx := s.db
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokens []*types.ProxyAccessToken
|
||||||
|
result := tx.Where("account_id = ?", accountID).Find(&tokens)
|
||||||
|
if result.Error != nil {
|
||||||
|
return nil, status.Errorf(status.Internal, "get proxy access tokens by account: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) {
|
||||||
|
token, err := s.GetProxyAccessTokenByID(ctx, LockingStrengthNone, tokenID)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return token.IsValid(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error) {
|
||||||
|
tx := s.db
|
||||||
|
if lockStrength != LockingStrengthNone {
|
||||||
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var token types.ProxyAccessToken
|
||||||
|
result := tx.Take(&token, idQueryCondition, tokenID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "proxy access token not found")
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(status.Internal, "get proxy access token by ID: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
|
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
|
||||||
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
|
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
|
||||||
result := s.db.Model(&types.ProxyAccessToken{}).
|
result := s.db.Model(&types.ProxyAccessToken{}).
|
||||||
@@ -5439,13 +5478,29 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy or creates a new entry if it doesn't exist
|
func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID string) error {
|
||||||
|
now := time.Now()
|
||||||
|
result := s.db.
|
||||||
|
Model(&proxy.Proxy{}).
|
||||||
|
Where("id = ?", proxyID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"status": proxy.StatusDisconnected,
|
||||||
|
"disconnected_at": now,
|
||||||
|
"last_seen": now,
|
||||||
|
})
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to disconnect proxy: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to disconnect proxy")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
result := s.db.
|
result := s.db.
|
||||||
Model(&proxy.Proxy{}).
|
Model(&proxy.Proxy{}).
|
||||||
Where("id = ? AND status = ?", proxyID, "connected").
|
Where("id = ? AND status = ?", proxyID, proxy.StatusConnected).
|
||||||
Update("last_seen", now)
|
Update("last_seen", now)
|
||||||
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
@@ -5471,13 +5526,15 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies
|
// GetActiveProxyClusterAddresses returns the unique cluster addresses of active
|
||||||
|
// shared proxies (those without an account scope). BYOP cluster addresses are
|
||||||
|
// excluded; use GetActiveProxyClusterAddressesForAccount to retrieve them.
|
||||||
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||||
var addresses []string
|
var addresses []string
|
||||||
|
|
||||||
result := s.db.
|
result := s.db.
|
||||||
Model(&proxy.Proxy{}).
|
Model(&proxy.Proxy{}).
|
||||||
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
|
Where("account_id IS NULL AND status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)).
|
||||||
Distinct("cluster_address").
|
Distinct("cluster_address").
|
||||||
Pluck("cluster_address", &addresses)
|
Pluck("cluster_address", &addresses)
|
||||||
|
|
||||||
@@ -5489,13 +5546,75 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
|
|||||||
return addresses, nil
|
return addresses, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveProxyClusters returns all active proxy clusters with their connected proxy count.
|
func (s *SqlStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||||
func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
|
var addresses []string
|
||||||
|
|
||||||
|
result := s.db.
|
||||||
|
Model(&proxy.Proxy{}).
|
||||||
|
Where("account_id = ? AND status = ? AND last_seen > ?", accountID, proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)).
|
||||||
|
Distinct("cluster_address").
|
||||||
|
Pluck("cluster_address", &addresses)
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses for account")
|
||||||
|
}
|
||||||
|
|
||||||
|
return addresses, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||||
|
var p proxy.Proxy
|
||||||
|
result := s.db.Where("account_id = ?", accountID).Take(&p)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "proxy not found for account")
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(status.Internal, "get proxy by account ID: %v", result.Error)
|
||||||
|
}
|
||||||
|
return &p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
|
||||||
|
var count int64
|
||||||
|
result := s.db.Model(&proxy.Proxy{}).Where("account_id = ?", accountID).Count(&count)
|
||||||
|
if result.Error != nil {
|
||||||
|
return 0, status.Errorf(status.Internal, "count proxies by account ID: %v", result.Error)
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||||
|
var count int64
|
||||||
|
result := s.db.
|
||||||
|
Model(&proxy.Proxy{}).
|
||||||
|
Where("cluster_address = ? AND (account_id IS NULL OR account_id != ?)", clusterAddress, accountID).
|
||||||
|
Count(&count)
|
||||||
|
if result.Error != nil {
|
||||||
|
return false, status.Errorf(status.Internal, "check cluster address conflict: %v", result.Error)
|
||||||
|
}
|
||||||
|
return count > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||||
|
result := s.db.
|
||||||
|
Where("cluster_address = ? AND account_id = ?", clusterAddress, accountID).
|
||||||
|
Delete(&proxy.Proxy{})
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "delete account cluster: %v", result.Error)
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "cluster not found")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
|
||||||
var clusters []proxy.Cluster
|
var clusters []proxy.Cluster
|
||||||
|
|
||||||
result := s.db.Model(&proxy.Proxy{}).
|
result := s.db.Model(&proxy.Proxy{}).
|
||||||
Select("cluster_address as address, COUNT(*) as connected_proxies").
|
Select("MIN(id) as id, cluster_address as address, COUNT(*) as connected_proxies, COUNT(account_id) > 0 as self_hosted").
|
||||||
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
|
Where("status = ? AND last_seen > ? AND (account_id IS NULL OR account_id = ?)",
|
||||||
|
proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold), accountID).
|
||||||
Group("cluster_address").
|
Group("cluster_address").
|
||||||
Scan(&clusters)
|
Scan(&clusters)
|
||||||
|
|
||||||
|
|||||||
@@ -114,6 +114,9 @@ type Store interface {
|
|||||||
|
|
||||||
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
|
GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error)
|
||||||
GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error)
|
GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error)
|
||||||
|
GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.ProxyAccessToken, error)
|
||||||
|
GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types.ProxyAccessToken, error)
|
||||||
|
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
|
||||||
SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error
|
SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error
|
||||||
RevokeProxyAccessToken(ctx context.Context, tokenID string) error
|
RevokeProxyAccessToken(ctx context.Context, tokenID string) error
|
||||||
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
|
MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error
|
||||||
@@ -284,13 +287,19 @@ type Store interface {
|
|||||||
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
|
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
|
||||||
|
|
||||||
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
|
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
|
||||||
|
DisconnectProxy(ctx context.Context, proxyID string) error
|
||||||
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||||
|
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||||
|
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||||
|
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||||
|
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||||
|
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
|
||||||
|
|
||||||
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||||
|
|
||||||
@@ -494,6 +503,9 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
|||||||
func(db *gorm.DB) error {
|
func(db *gorm.DB) error {
|
||||||
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key")
|
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key")
|
||||||
},
|
},
|
||||||
|
func(db *gorm.DB) error {
|
||||||
|
return migration.DropIndex[proxy.Proxy](ctx, db, "idx_proxy_account_id_unique")
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -165,19 +165,6 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClusterSupportsCrowdSec mocks base method.
|
|
||||||
func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr)
|
|
||||||
ret0, _ := ret[0].(*bool)
|
|
||||||
return ret0
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec.
|
|
||||||
func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
|
|
||||||
}
|
|
||||||
// Close mocks base method.
|
// Close mocks base method.
|
||||||
func (m *MockStore) Close(ctx context.Context) error {
|
func (m *MockStore) Close(ctx context.Context) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -236,6 +223,21 @@ func (mr *MockStoreMockRecorder) CountEphemeralServicesByPeer(ctx, lockStrength,
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CountProxiesByAccountID mocks base method.
|
||||||
|
func (m *MockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "CountProxiesByAccountID", ctx, accountID)
|
||||||
|
ret0, _ := ret[0].(int64)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountProxiesByAccountID indicates an expected call of CountProxiesByAccountID.
|
||||||
|
func (mr *MockStoreMockRecorder) CountProxiesByAccountID(ctx, accountID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountProxiesByAccountID", reflect.TypeOf((*MockStore)(nil).CountProxiesByAccountID), ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// CreateAccessLog mocks base method.
|
// CreateAccessLog mocks base method.
|
||||||
func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error {
|
func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -574,6 +576,20 @@ func (mr *MockStoreMockRecorder) DeletePostureChecks(ctx, accountID, postureChec
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePostureChecks", reflect.TypeOf((*MockStore)(nil).DeletePostureChecks), ctx, accountID, postureChecksID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePostureChecks", reflect.TypeOf((*MockStore)(nil).DeletePostureChecks), ctx, accountID, postureChecksID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteAccountCluster mocks base method.
|
||||||
|
func (m *MockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
|
||||||
|
func (mr *MockStoreMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockStore)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteRoute mocks base method.
|
// DeleteRoute mocks base method.
|
||||||
func (m *MockStore) DeleteRoute(ctx context.Context, accountID, routeID string) error {
|
func (m *MockStore) DeleteRoute(ctx context.Context, accountID, routeID string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -714,6 +730,20 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DisconnectProxy mocks base method.
|
||||||
|
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DisconnectProxy", ctx, proxyID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisconnectProxy indicates an expected call of DisconnectProxy.
|
||||||
|
func (mr *MockStoreMockRecorder) DisconnectProxy(ctx, proxyID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectProxy", reflect.TypeOf((*MockStore)(nil).DisconnectProxy), ctx, proxyID)
|
||||||
|
}
|
||||||
|
|
||||||
// EphemeralServiceExists mocks base method.
|
// EphemeralServiceExists mocks base method.
|
||||||
func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
|
func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -1300,19 +1330,34 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{})
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveProxyClusters mocks base method.
|
// GetActiveProxyClusterAddressesForAccount mocks base method.
|
||||||
func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) {
|
func (m *MockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx)
|
ret := m.ctrl.Call(m, "GetActiveProxyClusterAddressesForAccount", ctx, accountID)
|
||||||
|
ret0, _ := ret[0].([]string)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveProxyClusterAddressesForAccount indicates an expected call of GetActiveProxyClusterAddressesForAccount.
|
||||||
|
func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddressesForAccount", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddressesForAccount), ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveProxyClusters mocks base method.
|
||||||
|
func (m *MockStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx, accountID)
|
||||||
ret0, _ := ret[0].([]proxy.Cluster)
|
ret0, _ := ret[0].([]proxy.Cluster)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters.
|
// GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters.
|
||||||
func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx interface{}) *gomock.Call {
|
func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx, accountID interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllAccounts mocks base method.
|
// GetAllAccounts mocks base method.
|
||||||
@@ -1388,6 +1433,20 @@ func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr int
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetClusterSupportsCrowdSec mocks base method.
|
||||||
|
func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr)
|
||||||
|
ret0, _ := ret[0].(*bool)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec.
|
||||||
|
func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||||
|
}
|
||||||
|
|
||||||
// GetClusterSupportsCustomPorts mocks base method.
|
// GetClusterSupportsCustomPorts mocks base method.
|
||||||
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -1957,6 +2016,51 @@ func (mr *MockStoreMockRecorder) GetProxyAccessTokenByHashedToken(ctx, lockStren
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByHashedToken", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByHashedToken), ctx, lockStrength, hashedToken)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByHashedToken", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByHashedToken), ctx, lockStrength, hashedToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetProxyAccessTokenByID mocks base method.
|
||||||
|
func (m *MockStore) GetProxyAccessTokenByID(ctx context.Context, lockStrength LockingStrength, tokenID string) (*types2.ProxyAccessToken, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetProxyAccessTokenByID", ctx, lockStrength, tokenID)
|
||||||
|
ret0, _ := ret[0].(*types2.ProxyAccessToken)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProxyAccessTokenByID indicates an expected call of GetProxyAccessTokenByID.
|
||||||
|
func (mr *MockStoreMockRecorder) GetProxyAccessTokenByID(ctx, lockStrength, tokenID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokenByID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokenByID), ctx, lockStrength, tokenID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProxyAccessTokensByAccountID mocks base method.
|
||||||
|
func (m *MockStore) GetProxyAccessTokensByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types2.ProxyAccessToken, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetProxyAccessTokensByAccountID", ctx, lockStrength, accountID)
|
||||||
|
ret0, _ := ret[0].([]*types2.ProxyAccessToken)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProxyAccessTokensByAccountID indicates an expected call of GetProxyAccessTokensByAccountID.
|
||||||
|
func (mr *MockStoreMockRecorder) GetProxyAccessTokensByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyAccessTokensByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyAccessTokensByAccountID), ctx, lockStrength, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProxyByAccountID mocks base method.
|
||||||
|
func (m *MockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetProxyByAccountID", ctx, accountID)
|
||||||
|
ret0, _ := ret[0].(*proxy.Proxy)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProxyByAccountID indicates an expected call of GetProxyByAccountID.
|
||||||
|
func (mr *MockStoreMockRecorder) GetProxyByAccountID(ctx, accountID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxyByAccountID", reflect.TypeOf((*MockStore)(nil).GetProxyByAccountID), ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// GetResourceGroups mocks base method.
|
// GetResourceGroups mocks base method.
|
||||||
func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) {
|
func (m *MockStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types2.Group, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -2389,6 +2493,21 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsClusterAddressConflicting mocks base method.
|
||||||
|
func (m *MockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "IsClusterAddressConflicting", ctx, clusterAddress, accountID)
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClusterAddressConflicting indicates an expected call of IsClusterAddressConflicting.
|
||||||
|
func (mr *MockStoreMockRecorder) IsClusterAddressConflicting(ctx, clusterAddress, accountID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressConflicting", reflect.TypeOf((*MockStore)(nil).IsClusterAddressConflicting), ctx, clusterAddress, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// IsPrimaryAccount mocks base method.
|
// IsPrimaryAccount mocks base method.
|
||||||
func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
|
func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -2405,6 +2524,21 @@ func (mr *MockStoreMockRecorder) IsPrimaryAccount(ctx, accountID interface{}) *g
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPrimaryAccount", reflect.TypeOf((*MockStore)(nil).IsPrimaryAccount), ctx, accountID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPrimaryAccount", reflect.TypeOf((*MockStore)(nil).IsPrimaryAccount), ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsProxyAccessTokenValid mocks base method.
|
||||||
|
func (m *MockStore) IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "IsProxyAccessTokenValid", ctx, tokenID)
|
||||||
|
ret0, _ := ret[0].(bool)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsProxyAccessTokenValid indicates an expected call of IsProxyAccessTokenValid.
|
||||||
|
func (mr *MockStoreMockRecorder) IsProxyAccessTokenValid(ctx, tokenID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsProxyAccessTokenValid", reflect.TypeOf((*MockStore)(nil).IsProxyAccessTokenValid), ctx, tokenID)
|
||||||
|
}
|
||||||
|
|
||||||
// ListCustomDomains mocks base method.
|
// ListCustomDomains mocks base method.
|
||||||
func (m *MockStore) ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error) {
|
func (m *MockStore) ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
@@ -27,7 +26,6 @@ import (
|
|||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
@@ -110,16 +108,9 @@ type Account struct {
|
|||||||
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
|
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
|
||||||
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
|
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
|
||||||
|
|
||||||
NetworkMapCache *NetworkMapBuilder `gorm:"-"`
|
|
||||||
nmapInitOnce *sync.Once `gorm:"-"`
|
|
||||||
|
|
||||||
ReverseProxyFreeDomainNonce string
|
ReverseProxyFreeDomainNonce string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) InitOnce() {
|
|
||||||
a.nmapInitOnce = &sync.Once{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// this class is used by gorm only
|
// this class is used by gorm only
|
||||||
type PrimaryAccountInfo struct {
|
type PrimaryAccountInfo struct {
|
||||||
IsDomainPrimaryAccount bool
|
IsDomainPrimaryAccount bool
|
||||||
@@ -155,108 +146,6 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
|
|||||||
o.SignupFormPending == onboarding.SignupFormPending
|
o.SignupFormPending == onboarding.SignupFormPending
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
|
|
||||||
// from the ACL peers that have distribution groups associated with the peer ID.
|
|
||||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
|
||||||
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route {
|
|
||||||
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)
|
|
||||||
peerRoutesMembership := make(LookupMap)
|
|
||||||
for _, r := range append(routes, peerDisabledRoutes...) {
|
|
||||||
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, peer := range aclPeers {
|
|
||||||
activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID)
|
|
||||||
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, peerGroups)
|
|
||||||
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
|
|
||||||
routes = append(routes, filteredRoutes...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return routes
|
|
||||||
}
|
|
||||||
|
|
||||||
// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership
|
|
||||||
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route {
|
|
||||||
var filteredRoutes []*route.Route
|
|
||||||
for _, r := range routes {
|
|
||||||
_, found := peerMemberships[string(r.GetHAUniqueID())]
|
|
||||||
if !found {
|
|
||||||
filteredRoutes = append(filteredRoutes, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return filteredRoutes
|
|
||||||
}
|
|
||||||
|
|
||||||
// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map
|
|
||||||
func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
|
|
||||||
var filteredRoutes []*route.Route
|
|
||||||
for _, r := range routes {
|
|
||||||
for _, groupID := range r.Groups {
|
|
||||||
_, found := groupListMap[groupID]
|
|
||||||
if found {
|
|
||||||
filteredRoutes = append(filteredRoutes, r)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return filteredRoutes
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves
|
|
||||||
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
|
|
||||||
// If the given is not a routing peer, then the lists are empty.
|
|
||||||
func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
|
|
||||||
|
|
||||||
peer := a.GetPeer(peerID)
|
|
||||||
if peer == nil {
|
|
||||||
log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id)
|
|
||||||
return enabledRoutes, disabledRoutes
|
|
||||||
}
|
|
||||||
|
|
||||||
seenRoute := make(map[route.ID]struct{})
|
|
||||||
|
|
||||||
takeRoute := func(r *route.Route, id string) {
|
|
||||||
if _, ok := seenRoute[r.ID]; ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
seenRoute[r.ID] = struct{}{}
|
|
||||||
|
|
||||||
if r.Enabled {
|
|
||||||
r.Peer = peer.Key
|
|
||||||
enabledRoutes = append(enabledRoutes, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
disabledRoutes = append(disabledRoutes, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range a.Routes {
|
|
||||||
for _, groupID := range r.PeerGroups {
|
|
||||||
group := a.GetGroup(groupID)
|
|
||||||
if group == nil {
|
|
||||||
log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, id := range group.Peers {
|
|
||||||
if id != peerID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
newPeerRoute := r.Copy()
|
|
||||||
newPeerRoute.Peer = id
|
|
||||||
newPeerRoute.PeerGroups = nil
|
|
||||||
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map
|
|
||||||
takeRoute(newPeerRoute, id)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if r.Peer == peerID {
|
|
||||||
takeRoute(r.Copy(), peerID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return enabledRoutes, disabledRoutes
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||||
func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route {
|
func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route {
|
||||||
var routes []*route.Route
|
var routes []*route.Route
|
||||||
@@ -276,106 +165,6 @@ func (a *Account) GetGroup(groupID string) *Group {
|
|||||||
return a.Groups[groupID]
|
return a.Groups[groupID]
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeerNetworkMap returns the networkmap for the given peer ID.
|
|
||||||
func (a *Account) GetPeerNetworkMap(
|
|
||||||
ctx context.Context,
|
|
||||||
peerID string,
|
|
||||||
peersCustomZone nbdns.CustomZone,
|
|
||||||
accountZones []*zones.Zone,
|
|
||||||
validatedPeersMap map[string]struct{},
|
|
||||||
resourcePolicies map[string][]*Policy,
|
|
||||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
|
||||||
metrics *telemetry.AccountManagerMetrics,
|
|
||||||
groupIDToUserIDs map[string][]string,
|
|
||||||
) *NetworkMap {
|
|
||||||
start := time.Now()
|
|
||||||
peer := a.Peers[peerID]
|
|
||||||
if peer == nil {
|
|
||||||
return &NetworkMap{
|
|
||||||
Network: a.Network.Copy(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := validatedPeersMap[peerID]; !ok {
|
|
||||||
return &NetworkMap{
|
|
||||||
Network: a.Network.Copy(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
peerGroups := a.GetPeerGroups(peerID)
|
|
||||||
|
|
||||||
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
|
|
||||||
// exclude expired peers
|
|
||||||
var peersToConnect []*nbpeer.Peer
|
|
||||||
var expiredPeers []*nbpeer.Peer
|
|
||||||
for _, p := range aclPeers {
|
|
||||||
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
|
|
||||||
if a.Settings.PeerLoginExpirationEnabled && expired {
|
|
||||||
expiredPeers = append(expiredPeers, p)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
peersToConnect = append(peersToConnect, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect, peerGroups)
|
|
||||||
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
|
|
||||||
var networkResourcesFirewallRules []*RouteFirewallRule
|
|
||||||
if isRouter {
|
|
||||||
networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies)
|
|
||||||
}
|
|
||||||
peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers)
|
|
||||||
|
|
||||||
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
|
|
||||||
dnsUpdate := nbdns.Config{
|
|
||||||
ServiceEnable: dnsManagementStatus,
|
|
||||||
}
|
|
||||||
|
|
||||||
if dnsManagementStatus {
|
|
||||||
var zones []nbdns.CustomZone
|
|
||||||
|
|
||||||
if peersCustomZone.Domain != "" {
|
|
||||||
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
|
|
||||||
zones = append(zones, nbdns.CustomZone{
|
|
||||||
Domain: peersCustomZone.Domain,
|
|
||||||
Records: records,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
|
|
||||||
zones = append(zones, filteredAccountZones...)
|
|
||||||
|
|
||||||
dnsUpdate.CustomZones = zones
|
|
||||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
|
||||||
}
|
|
||||||
|
|
||||||
nm := &NetworkMap{
|
|
||||||
Peers: peersToConnectIncludingRouters,
|
|
||||||
Network: a.Network.Copy(),
|
|
||||||
Routes: slices.Concat(networkResourcesRoutes, routesUpdate),
|
|
||||||
DNSConfig: dnsUpdate,
|
|
||||||
OfflinePeers: expiredPeers,
|
|
||||||
FirewallRules: firewallRules,
|
|
||||||
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
|
|
||||||
AuthorizedUsers: authorizedUsers,
|
|
||||||
EnableSSH: enableSSH,
|
|
||||||
}
|
|
||||||
|
|
||||||
if metrics != nil {
|
|
||||||
objectCount := int64(len(peersToConnectIncludingRouters) + len(expiredPeers) + len(routesUpdate) + len(networkResourcesRoutes) + len(firewallRules) + +len(networkResourcesFirewallRules) + len(routesFirewallRules))
|
|
||||||
metrics.CountNetworkMapObjects(objectCount)
|
|
||||||
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
|
|
||||||
|
|
||||||
if objectCount > 5000 {
|
|
||||||
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+
|
|
||||||
"peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d, network resources routes: %d, network resources firewall rules: %d, routes firewall rules: %d",
|
|
||||||
a.Id, objectCount, len(peersToConnectIncludingRouters), len(expiredPeers), len(routesUpdate), len(firewallRules), len(networkResourcesRoutes), len(networkResourcesFirewallRules), len(routesFirewallRules))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nm
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) addNetworksRoutingPeers(
|
func (a *Account) addNetworksRoutingPeers(
|
||||||
networkResourcesRoutes []*route.Route,
|
networkResourcesRoutes []*route.Route,
|
||||||
peer *nbpeer.Peer,
|
peer *nbpeer.Peer,
|
||||||
@@ -421,39 +210,6 @@ func (a *Account) addNetworksRoutingPeers(
|
|||||||
return peersToConnect
|
return peersToConnect
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
|
||||||
groupList := account.GetPeerGroups(peerID)
|
|
||||||
|
|
||||||
var peerNSGroups []*nbdns.NameServerGroup
|
|
||||||
|
|
||||||
for _, nsGroup := range account.NameServerGroups {
|
|
||||||
if !nsGroup.Enabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, gID := range nsGroup.Groups {
|
|
||||||
_, found := groupList[gID]
|
|
||||||
if found {
|
|
||||||
if !peerIsNameserver(account.GetPeer(peerID), nsGroup) {
|
|
||||||
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return peerNSGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup
|
|
||||||
func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
|
|
||||||
for _, ns := range nsGroup.NameServers {
|
|
||||||
if peer.IP.Equal(ns.IP.AsSlice()) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) {
|
func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) {
|
||||||
for _, peer := range account.Peers {
|
for _, peer := range account.Peers {
|
||||||
label, err := GetPeerHostLabel(peer.Name, peerLabels)
|
label, err := GetPeerHostLabel(peer.Name, peerLabels)
|
||||||
@@ -800,19 +556,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
|
|||||||
return grps
|
return grps
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
|
|
||||||
peerGroups := a.GetPeerGroups(peerID)
|
|
||||||
enabled := true
|
|
||||||
for _, groupID := range a.DNSSettings.DisabledManagementGroups {
|
|
||||||
_, found := peerGroups[groupID]
|
|
||||||
if found {
|
|
||||||
enabled = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return enabled
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) GetPeerGroups(peerID string) LookupMap {
|
func (a *Account) GetPeerGroups(peerID string) LookupMap {
|
||||||
groupList := make(LookupMap)
|
groupList := make(LookupMap)
|
||||||
for groupID, group := range a.Groups {
|
for groupID, group := range a.Groups {
|
||||||
@@ -941,8 +684,6 @@ func (a *Account) Copy() *Account {
|
|||||||
NetworkResources: networkResources,
|
NetworkResources: networkResources,
|
||||||
Services: services,
|
Services: services,
|
||||||
Onboarding: a.Onboarding,
|
Onboarding: a.Onboarding,
|
||||||
NetworkMapCache: a.NetworkMapCache,
|
|
||||||
nmapInitOnce: a.nmapInitOnce,
|
|
||||||
Domains: domains,
|
Domains: domains,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1304,31 +1045,6 @@ func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
|
|
||||||
func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
|
|
||||||
routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))
|
|
||||||
|
|
||||||
enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
|
|
||||||
for _, route := range enabledRoutes {
|
|
||||||
// If no access control groups are specified, accept all traffic.
|
|
||||||
if len(route.AccessControlGroups) == 0 {
|
|
||||||
defaultPermit := getDefaultPermit(route)
|
|
||||||
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
distributionPeers := a.getDistributionGroupsPeers(route)
|
|
||||||
|
|
||||||
for _, accessGroup := range route.AccessControlGroups {
|
|
||||||
policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup})
|
|
||||||
rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers)
|
|
||||||
routesFirewallRules = append(routesFirewallRules, rules...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return routesFirewallRules
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule {
|
func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule {
|
||||||
var fwRules []*RouteFirewallRule
|
var fwRules []*RouteFirewallRule
|
||||||
for _, policy := range policies {
|
for _, policy := range policies {
|
||||||
@@ -1387,50 +1103,6 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID
|
|||||||
return distributionGroupPeers
|
return distributionGroupPeers
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} {
|
|
||||||
distPeers := make(map[string]struct{})
|
|
||||||
for _, id := range route.Groups {
|
|
||||||
group := a.Groups[id]
|
|
||||||
if group == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, pID := range group.Peers {
|
|
||||||
distPeers[pID] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return distPeers
|
|
||||||
}
|
|
||||||
|
|
||||||
func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
|
|
||||||
var rules []*RouteFirewallRule
|
|
||||||
|
|
||||||
sources := []string{"0.0.0.0/0"}
|
|
||||||
if route.Network.Addr().Is6() {
|
|
||||||
sources = []string{"::/0"}
|
|
||||||
}
|
|
||||||
rule := RouteFirewallRule{
|
|
||||||
SourceRanges: sources,
|
|
||||||
Action: string(PolicyTrafficActionAccept),
|
|
||||||
Destination: route.Network.String(),
|
|
||||||
Protocol: string(PolicyRuleProtocolALL),
|
|
||||||
Domains: route.Domains,
|
|
||||||
IsDynamic: route.IsDynamic(),
|
|
||||||
RouteID: route.ID,
|
|
||||||
}
|
|
||||||
|
|
||||||
rules = append(rules, &rule)
|
|
||||||
|
|
||||||
// dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally
|
|
||||||
if route.IsDynamic() {
|
|
||||||
ruleV6 := rule
|
|
||||||
ruleV6.SourceRanges = []string{"::/0"}
|
|
||||||
rules = append(rules, &ruleV6)
|
|
||||||
}
|
|
||||||
|
|
||||||
return rules
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
|
// GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
|
||||||
// and returns a list of policies that have rules with destinations matching the specified groups.
|
// and returns a list of policies that have rules with destinations matching the specified groups.
|
||||||
func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
|
func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
|
||||||
@@ -1508,65 +1180,6 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
|
|||||||
return resourcePolicies
|
return resourcePolicies
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
|
|
||||||
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
|
|
||||||
var isRoutingPeer bool
|
|
||||||
var routes []*route.Route
|
|
||||||
allSourcePeers := make(map[string]struct{}, len(a.Peers))
|
|
||||||
|
|
||||||
for _, resource := range a.NetworkResources {
|
|
||||||
if !resource.Enabled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var addSourcePeers bool
|
|
||||||
|
|
||||||
networkRoutingPeers, exists := routers[resource.NetworkID]
|
|
||||||
if exists {
|
|
||||||
if router, ok := networkRoutingPeers[peerID]; ok {
|
|
||||||
isRoutingPeer, addSourcePeers = true, true
|
|
||||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, router, resourcePolicies)...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
addedResourceRoute := false
|
|
||||||
for _, policy := range resourcePolicies[resource.ID] {
|
|
||||||
var peers []string
|
|
||||||
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
|
|
||||||
peers = []string{policy.Rules[0].SourceResource.ID}
|
|
||||||
} else {
|
|
||||||
peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
|
|
||||||
}
|
|
||||||
if addSourcePeers {
|
|
||||||
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
|
|
||||||
allSourcePeers[pID] = struct{}{}
|
|
||||||
}
|
|
||||||
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
|
|
||||||
// add routes for the resource if the peer is in the distribution group
|
|
||||||
for peerId, router := range networkRoutingPeers {
|
|
||||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
|
|
||||||
}
|
|
||||||
addedResourceRoute = true
|
|
||||||
}
|
|
||||||
if addedResourceRoute {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return isRoutingPeer, routes, allSourcePeers
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
|
|
||||||
var dest []string
|
|
||||||
for _, peerID := range inputPeers {
|
|
||||||
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
|
|
||||||
dest = append(dest, peerID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return dest
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
|
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
|
||||||
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
|
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
|
||||||
for _, groupID := range groups {
|
for _, groupID := range groups {
|
||||||
@@ -1658,22 +1271,6 @@ func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// getNetworkResourcesRoutes convert the network resources list to routes list.
|
|
||||||
func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route {
|
|
||||||
resourceAppliedPolicies := resourcePolicies[resource.ID]
|
|
||||||
|
|
||||||
var routes []*route.Route
|
|
||||||
// distribute the resource routes only if there is policy applied to it
|
|
||||||
if len(resourceAppliedPolicies) > 0 {
|
|
||||||
peer := a.GetPeer(peerId)
|
|
||||||
if peer != nil {
|
|
||||||
routes = append(routes, resource.ToRoute(peer, router))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return routes
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter {
|
func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter {
|
||||||
routers := make(map[string]map[string]*routerTypes.NetworkRouter)
|
routers := make(map[string]map[string]*routerTypes.NetworkRouter)
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"slices"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@@ -19,7 +17,6 @@ import (
|
|||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -451,402 +448,6 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) {
|
|||||||
require.Len(t, result, 0)
|
require.Len(t, result, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
|
||||||
accID = "accountID"
|
|
||||||
network1ID = "network1ID"
|
|
||||||
group1ID = "group1"
|
|
||||||
accNetResourcePeer1ID = "peer1"
|
|
||||||
accNetResourcePeer2ID = "peer2"
|
|
||||||
accNetResourceRouter1ID = "router1"
|
|
||||||
accNetResource1ID = "resource1ID"
|
|
||||||
accNetResourceRestrictPostureCheckID = "restrictPostureCheck"
|
|
||||||
accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck"
|
|
||||||
accNetResourceLockedPostureCheckID = "lockedPostureCheck"
|
|
||||||
accNetResourceLinuxPostureCheckID = "linuxPostureCheck"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
accNetResourcePeer1IP = net.IP{192, 168, 1, 1}
|
|
||||||
accNetResourcePeer2IP = net.IP{192, 168, 1, 2}
|
|
||||||
accNetResourceRouter1IP = net.IP{192, 168, 1, 3}
|
|
||||||
accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}}
|
|
||||||
)
|
|
||||||
|
|
||||||
func getBasicAccountsWithResource() *Account {
|
|
||||||
return &Account{
|
|
||||||
Id: accID,
|
|
||||||
Peers: map[string]*nbpeer.Peer{
|
|
||||||
accNetResourcePeer1ID: {
|
|
||||||
ID: accNetResourcePeer1ID,
|
|
||||||
AccountID: accID,
|
|
||||||
Key: "peer1Key",
|
|
||||||
IP: accNetResourcePeer1IP,
|
|
||||||
Meta: nbpeer.PeerSystemMeta{
|
|
||||||
GoOS: "linux",
|
|
||||||
WtVersion: "0.35.1",
|
|
||||||
KernelVersion: "4.4.0",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
accNetResourcePeer2ID: {
|
|
||||||
ID: accNetResourcePeer2ID,
|
|
||||||
AccountID: accID,
|
|
||||||
Key: "peer2Key",
|
|
||||||
IP: accNetResourcePeer2IP,
|
|
||||||
Meta: nbpeer.PeerSystemMeta{
|
|
||||||
GoOS: "windows",
|
|
||||||
WtVersion: "0.34.1",
|
|
||||||
KernelVersion: "4.4.0",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
accNetResourceRouter1ID: {
|
|
||||||
ID: accNetResourceRouter1ID,
|
|
||||||
AccountID: accID,
|
|
||||||
Key: "router1Key",
|
|
||||||
IP: accNetResourceRouter1IP,
|
|
||||||
Meta: nbpeer.PeerSystemMeta{
|
|
||||||
GoOS: "linux",
|
|
||||||
WtVersion: "0.35.1",
|
|
||||||
KernelVersion: "4.4.0",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Groups: map[string]*Group{
|
|
||||||
group1ID: {
|
|
||||||
ID: group1ID,
|
|
||||||
Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Networks: []*networkTypes.Network{
|
|
||||||
{
|
|
||||||
ID: network1ID,
|
|
||||||
AccountID: accID,
|
|
||||||
Name: "network1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NetworkRouters: []*routerTypes.NetworkRouter{
|
|
||||||
{
|
|
||||||
ID: accNetResourceRouter1ID,
|
|
||||||
NetworkID: network1ID,
|
|
||||||
AccountID: accID,
|
|
||||||
Peer: accNetResourceRouter1ID,
|
|
||||||
PeerGroups: []string{},
|
|
||||||
Masquerade: false,
|
|
||||||
Metric: 100,
|
|
||||||
Enabled: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NetworkResources: []*resourceTypes.NetworkResource{
|
|
||||||
{
|
|
||||||
ID: accNetResource1ID,
|
|
||||||
AccountID: accID,
|
|
||||||
NetworkID: network1ID,
|
|
||||||
Address: "10.10.10.0/24",
|
|
||||||
Prefix: netip.MustParsePrefix("10.10.10.0/24"),
|
|
||||||
Type: resourceTypes.NetworkResourceType("subnet"),
|
|
||||||
Enabled: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Policies: []*Policy{
|
|
||||||
{
|
|
||||||
ID: "policy1ID",
|
|
||||||
AccountID: accID,
|
|
||||||
Enabled: true,
|
|
||||||
Rules: []*PolicyRule{
|
|
||||||
{
|
|
||||||
ID: "rule1ID",
|
|
||||||
Enabled: true,
|
|
||||||
Sources: []string{group1ID},
|
|
||||||
DestinationResource: Resource{
|
|
||||||
ID: accNetResource1ID,
|
|
||||||
Type: "Host",
|
|
||||||
},
|
|
||||||
Protocol: PolicyRuleProtocolTCP,
|
|
||||||
Ports: []string{"80"},
|
|
||||||
Action: PolicyTrafficActionAccept,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
SourcePostureChecks: nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
PostureChecks: []*posture.Checks{
|
|
||||||
{
|
|
||||||
ID: accNetResourceRestrictPostureCheckID,
|
|
||||||
Name: accNetResourceRestrictPostureCheckID,
|
|
||||||
Checks: posture.ChecksDefinition{
|
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
|
||||||
MinVersion: "0.35.0",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: accNetResourceRelaxedPostureCheckID,
|
|
||||||
Name: accNetResourceRelaxedPostureCheckID,
|
|
||||||
Checks: posture.ChecksDefinition{
|
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
|
||||||
MinVersion: "0.0.1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: accNetResourceLockedPostureCheckID,
|
|
||||||
Name: accNetResourceLockedPostureCheckID,
|
|
||||||
Checks: posture.ChecksDefinition{
|
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
|
||||||
MinVersion: "7.7.7",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: accNetResourceLinuxPostureCheckID,
|
|
||||||
Name: accNetResourceLinuxPostureCheckID,
|
|
||||||
Checks: posture.ChecksDefinition{
|
|
||||||
OSVersionCheck: &posture.OSVersionCheck{
|
|
||||||
Linux: &posture.MinKernelVersionCheck{
|
|
||||||
MinKernelVersion: "0.0.0"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) {
|
|
||||||
account := getBasicAccountsWithResource()
|
|
||||||
|
|
||||||
// all peers should match the policy
|
|
||||||
|
|
||||||
// validate for peer1
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
|
||||||
assert.False(t, isRouter, "expected router status")
|
|
||||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
|
||||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
|
||||||
|
|
||||||
// validate for peer2
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
|
||||||
assert.False(t, isRouter, "expected router status")
|
|
||||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
|
||||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
|
||||||
|
|
||||||
// validate routes for router1
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
|
||||||
assert.True(t, isRouter, "should be router")
|
|
||||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
|
||||||
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
|
|
||||||
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
|
||||||
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
|
|
||||||
|
|
||||||
// validate rules for router1
|
|
||||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
|
||||||
assert.Len(t, rules, 1, "expected rules count don't match")
|
|
||||||
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
|
||||||
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
|
||||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
|
||||||
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
|
||||||
}
|
|
||||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
|
||||||
t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) {
|
|
||||||
account := getBasicAccountsWithResource()
|
|
||||||
|
|
||||||
// should allow peer1 to match the policy
|
|
||||||
policy := account.Policies[0]
|
|
||||||
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
|
|
||||||
|
|
||||||
// validate for peer1
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
|
||||||
assert.False(t, isRouter, "expected router status")
|
|
||||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
|
||||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
|
||||||
|
|
||||||
// validate for peer2
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
|
||||||
assert.False(t, isRouter, "expected router status")
|
|
||||||
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
|
||||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
|
||||||
|
|
||||||
// validate routes for router1
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
|
||||||
assert.True(t, isRouter, "should be router")
|
|
||||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
|
||||||
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
|
|
||||||
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
|
||||||
|
|
||||||
// validate rules for router1
|
|
||||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
|
||||||
assert.Len(t, rules, 1, "expected rules count don't match")
|
|
||||||
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
|
||||||
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
|
||||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
|
||||||
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
|
||||||
}
|
|
||||||
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
|
||||||
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) {
|
|
||||||
account := getBasicAccountsWithResource()
|
|
||||||
|
|
||||||
// should not match any peer
|
|
||||||
policy := account.Policies[0]
|
|
||||||
policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID}
|
|
||||||
|
|
||||||
// validate for peer1
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
|
||||||
assert.False(t, isRouter, "expected router status")
|
|
||||||
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
|
||||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
|
||||||
|
|
||||||
// validate for peer2
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
|
||||||
assert.False(t, isRouter, "expected router status")
|
|
||||||
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
|
||||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
|
||||||
|
|
||||||
// validate routes for router1
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
|
||||||
assert.True(t, isRouter, "should be router")
|
|
||||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
|
||||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
|
||||||
|
|
||||||
// validate rules for router1
|
|
||||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
|
||||||
assert.Len(t, rules, 0, "expected rules count don't match")
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) {
|
|
||||||
account := getBasicAccountsWithResource()
|
|
||||||
|
|
||||||
// should allow peer1 to match the policy
|
|
||||||
policy := account.Policies[0]
|
|
||||||
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
|
|
||||||
|
|
||||||
// should allow peer1 and peer2 to match the policy
|
|
||||||
newPolicy := &Policy{
|
|
||||||
ID: "policy2ID",
|
|
||||||
AccountID: accID,
|
|
||||||
Enabled: true,
|
|
||||||
Rules: []*PolicyRule{
|
|
||||||
{
|
|
||||||
ID: "policy2ID",
|
|
||||||
Enabled: true,
|
|
||||||
Sources: []string{group1ID},
|
|
||||||
DestinationResource: Resource{
|
|
||||||
ID: accNetResource1ID,
|
|
||||||
Type: "Host",
|
|
||||||
},
|
|
||||||
Protocol: PolicyRuleProtocolTCP,
|
|
||||||
Ports: []string{"22"},
|
|
||||||
Action: PolicyTrafficActionAccept,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID},
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Policies = append(account.Policies, newPolicy)
|
|
||||||
|
|
||||||
// validate for peer1
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
|
||||||
assert.False(t, isRouter, "expected router status")
|
|
||||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
|
||||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
|
||||||
|
|
||||||
// validate for peer2
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
|
||||||
assert.False(t, isRouter, "expected router status")
|
|
||||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
|
||||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
|
||||||
|
|
||||||
// validate routes for router1
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
|
||||||
assert.True(t, isRouter, "should be router")
|
|
||||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
|
||||||
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
|
|
||||||
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
|
||||||
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
|
|
||||||
|
|
||||||
// validate rules for router1
|
|
||||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
|
||||||
assert.Len(t, rules, 2, "expected rules count don't match")
|
|
||||||
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
|
||||||
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
|
||||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
|
||||||
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
|
||||||
}
|
|
||||||
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
|
||||||
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, uint16(22), rules[1].Port, "should have port 22")
|
|
||||||
assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp")
|
|
||||||
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
|
||||||
t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String())
|
|
||||||
}
|
|
||||||
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
|
||||||
t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) {
|
|
||||||
account := getBasicAccountsWithResource()
|
|
||||||
|
|
||||||
// two posture checks should match only the peers that match both checks
|
|
||||||
policy := account.Policies[0]
|
|
||||||
policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID}
|
|
||||||
|
|
||||||
// validate for peer1
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
|
||||||
assert.False(t, isRouter, "expected router status")
|
|
||||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
|
||||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
|
||||||
|
|
||||||
// validate for peer2
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
|
||||||
assert.False(t, isRouter, "expected router status")
|
|
||||||
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
|
||||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
|
||||||
|
|
||||||
// validate routes for router1
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
|
||||||
assert.True(t, isRouter, "should be router")
|
|
||||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
|
||||||
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
|
|
||||||
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
|
||||||
|
|
||||||
// validate rules for router1
|
|
||||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
|
||||||
assert.Len(t, rules, 1, "expected rules count don't match")
|
|
||||||
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
|
||||||
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
|
||||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
|
||||||
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
|
||||||
}
|
|
||||||
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
|
||||||
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) {
|
|
||||||
account := getBasicAccountsWithResource()
|
|
||||||
|
|
||||||
account.Peers["router2Id"] = &nbpeer.Peer{Key: "router2Key", ID: "router2Id", AccountID: accID, IP: net.IP{192, 168, 1, 4}}
|
|
||||||
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
|
|
||||||
ID: "router2Id",
|
|
||||||
NetworkID: network1ID,
|
|
||||||
AccountID: accID,
|
|
||||||
Peer: "router2Id",
|
|
||||||
})
|
|
||||||
|
|
||||||
// validate routes for router1
|
|
||||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
|
||||||
assert.True(t, isRouter, "should be router")
|
|
||||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
|
||||||
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
|
func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -1,47 +0,0 @@
|
|||||||
package types
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Holder struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
accounts map[string]*Account
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHolder() *Holder {
|
|
||||||
return &Holder{
|
|
||||||
accounts: make(map[string]*Account),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Holder) GetAccount(id string) *Account {
|
|
||||||
h.mu.RLock()
|
|
||||||
defer h.mu.RUnlock()
|
|
||||||
return h.accounts[id]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Holder) AddAccount(account *Account) {
|
|
||||||
h.mu.Lock()
|
|
||||||
defer h.mu.Unlock()
|
|
||||||
a := h.accounts[account.Id]
|
|
||||||
if a != nil && a.Network.CurrentSerial() >= account.Network.CurrentSerial() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
h.accounts[account.Id] = account
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Holder) LoadOrStoreFunc(ctx context.Context, id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
|
|
||||||
h.mu.Lock()
|
|
||||||
defer h.mu.Unlock()
|
|
||||||
if acc, ok := h.accounts[id]; ok {
|
|
||||||
return acc, nil
|
|
||||||
}
|
|
||||||
account, err := accGetter(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
h.accounts[id] = account
|
|
||||||
return account, nil
|
|
||||||
}
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
package types
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) {
|
|
||||||
if a.NetworkMapCache != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
a.nmapInitOnce.Do(func() {
|
|
||||||
a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) {
|
|
||||||
a.initNetworkMapBuilder(validatedPeers)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) GetPeerNetworkMapExp(
|
|
||||||
ctx context.Context,
|
|
||||||
peerID string,
|
|
||||||
peersCustomZone nbdns.CustomZone,
|
|
||||||
accountZones []*zones.Zone,
|
|
||||||
validatedPeers map[string]struct{},
|
|
||||||
metrics *telemetry.AccountManagerMetrics,
|
|
||||||
) *NetworkMap {
|
|
||||||
a.initNetworkMapBuilder(validatedPeers)
|
|
||||||
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, accountZones, validatedPeers, metrics)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
|
|
||||||
if a.NetworkMapCache == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return a.NetworkMapCache.OnPeerAddedIncremental(a, peerId)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) OnPeersAddedUpdNetworkMapCache(peerIds ...string) {
|
|
||||||
if a.NetworkMapCache == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
a.NetworkMapCache.EnqueuePeersForIncrementalAdd(a, peerIds...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error {
|
|
||||||
if a.NetworkMapCache == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return a.NetworkMapCache.OnPeerDeleted(a, peerId)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) {
|
|
||||||
if a.NetworkMapCache == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
a.NetworkMapCache.UpdatePeer(peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) {
|
|
||||||
a.initNetworkMapBuilder(validatedPeers)
|
|
||||||
}
|
|
||||||
@@ -1,592 +0,0 @@
|
|||||||
package types
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"sort"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) {
|
|
||||||
account := createTestAccount()
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
peerID := testingPeerID
|
|
||||||
validatedPeersMap := make(map[string]struct{})
|
|
||||||
for i := range numPeers {
|
|
||||||
pid := fmt.Sprintf("peer-%d", i)
|
|
||||||
if pid == offlinePeerID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
validatedPeersMap[pid] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
peersCustomZone := nbdns.CustomZone{}
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(
|
|
||||||
ctx,
|
|
||||||
peerID,
|
|
||||||
peersCustomZone,
|
|
||||||
nil,
|
|
||||||
validatedPeersMap,
|
|
||||||
resourcePolicies,
|
|
||||||
routers,
|
|
||||||
nil,
|
|
||||||
groupIDToUserIDs,
|
|
||||||
)
|
|
||||||
|
|
||||||
components := account.GetPeerNetworkMapComponents(
|
|
||||||
ctx,
|
|
||||||
peerID,
|
|
||||||
peersCustomZone,
|
|
||||||
nil,
|
|
||||||
validatedPeersMap,
|
|
||||||
resourcePolicies,
|
|
||||||
routers,
|
|
||||||
groupIDToUserIDs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if components == nil {
|
|
||||||
t.Fatal("GetPeerNetworkMapComponents returned nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
|
|
||||||
|
|
||||||
if newNetworkMap == nil {
|
|
||||||
t.Fatal("CalculateNetworkMapFromComponents returned nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
compareNetworkMaps(t, legacyNetworkMap, newNetworkMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) {
|
|
||||||
account := createTestAccount()
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
peerID := testingPeerID
|
|
||||||
validatedPeersMap := make(map[string]struct{})
|
|
||||||
for i := range numPeers {
|
|
||||||
pid := fmt.Sprintf("peer-%d", i)
|
|
||||||
if pid == offlinePeerID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
validatedPeersMap[pid] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
peersCustomZone := nbdns.CustomZone{}
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(
|
|
||||||
ctx,
|
|
||||||
peerID,
|
|
||||||
peersCustomZone,
|
|
||||||
nil,
|
|
||||||
validatedPeersMap,
|
|
||||||
resourcePolicies,
|
|
||||||
routers,
|
|
||||||
nil,
|
|
||||||
groupIDToUserIDs,
|
|
||||||
)
|
|
||||||
|
|
||||||
components := account.GetPeerNetworkMapComponents(
|
|
||||||
ctx,
|
|
||||||
peerID,
|
|
||||||
peersCustomZone,
|
|
||||||
nil,
|
|
||||||
validatedPeersMap,
|
|
||||||
resourcePolicies,
|
|
||||||
routers,
|
|
||||||
groupIDToUserIDs,
|
|
||||||
)
|
|
||||||
|
|
||||||
require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil")
|
|
||||||
|
|
||||||
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
|
|
||||||
require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil")
|
|
||||||
|
|
||||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
|
||||||
normalizeAndSortNetworkMap(newNetworkMap)
|
|
||||||
|
|
||||||
componentsJSON, err := json.MarshalIndent(components, "", " ")
|
|
||||||
require.NoError(t, err, "error marshaling components to JSON")
|
|
||||||
|
|
||||||
legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ")
|
|
||||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
|
||||||
|
|
||||||
newJSON, err := json.MarshalIndent(newNetworkMap, "", " ")
|
|
||||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
|
||||||
|
|
||||||
goldenDir := filepath.Join("testdata", "comparison")
|
|
||||||
err = os.MkdirAll(goldenDir, 0755)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json")
|
|
||||||
err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644)
|
|
||||||
require.NoError(t, err, "error writing legacy golden file")
|
|
||||||
|
|
||||||
newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json")
|
|
||||||
err = os.WriteFile(newGoldenPath, newJSON, 0644)
|
|
||||||
require.NoError(t, err, "error writing components golden file")
|
|
||||||
|
|
||||||
componentsPath := filepath.Join(goldenDir, "components.json")
|
|
||||||
err = os.WriteFile(componentsPath, componentsJSON, 0644)
|
|
||||||
require.NoError(t, err, "error writing components golden file")
|
|
||||||
|
|
||||||
require.JSONEq(t, string(legacyJSON), string(newJSON),
|
|
||||||
"NetworkMaps from legacy and components approaches do not match.\n"+
|
|
||||||
"Legacy JSON saved to: %s\n"+
|
|
||||||
"Components JSON saved to: %s",
|
|
||||||
legacyGoldenPath, newGoldenPath)
|
|
||||||
|
|
||||||
t.Logf("✅ NetworkMaps are identical")
|
|
||||||
t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath)
|
|
||||||
t.Logf(" Components NetworkMap: %s", newGoldenPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeAndSortNetworkMap(nm *NetworkMap) {
|
|
||||||
if nm == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Slice(nm.Peers, func(i, j int) bool {
|
|
||||||
return nm.Peers[i].ID < nm.Peers[j].ID
|
|
||||||
})
|
|
||||||
|
|
||||||
sort.Slice(nm.OfflinePeers, func(i, j int) bool {
|
|
||||||
return nm.OfflinePeers[i].ID < nm.OfflinePeers[j].ID
|
|
||||||
})
|
|
||||||
|
|
||||||
sort.Slice(nm.Routes, func(i, j int) bool {
|
|
||||||
return string(nm.Routes[i].ID) < string(nm.Routes[j].ID)
|
|
||||||
})
|
|
||||||
|
|
||||||
sort.Slice(nm.FirewallRules, func(i, j int) bool {
|
|
||||||
if nm.FirewallRules[i].PeerIP != nm.FirewallRules[j].PeerIP {
|
|
||||||
return nm.FirewallRules[i].PeerIP < nm.FirewallRules[j].PeerIP
|
|
||||||
}
|
|
||||||
if nm.FirewallRules[i].Direction != nm.FirewallRules[j].Direction {
|
|
||||||
return nm.FirewallRules[i].Direction < nm.FirewallRules[j].Direction
|
|
||||||
}
|
|
||||||
if nm.FirewallRules[i].Protocol != nm.FirewallRules[j].Protocol {
|
|
||||||
return nm.FirewallRules[i].Protocol < nm.FirewallRules[j].Protocol
|
|
||||||
}
|
|
||||||
if nm.FirewallRules[i].Port != nm.FirewallRules[j].Port {
|
|
||||||
return nm.FirewallRules[i].Port < nm.FirewallRules[j].Port
|
|
||||||
}
|
|
||||||
return nm.FirewallRules[i].PolicyID < nm.FirewallRules[j].PolicyID
|
|
||||||
})
|
|
||||||
|
|
||||||
for i := range nm.RoutesFirewallRules {
|
|
||||||
sort.Strings(nm.RoutesFirewallRules[i].SourceRanges)
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Slice(nm.RoutesFirewallRules, func(i, j int) bool {
|
|
||||||
if nm.RoutesFirewallRules[i].Destination != nm.RoutesFirewallRules[j].Destination {
|
|
||||||
return nm.RoutesFirewallRules[i].Destination < nm.RoutesFirewallRules[j].Destination
|
|
||||||
}
|
|
||||||
|
|
||||||
minLen := len(nm.RoutesFirewallRules[i].SourceRanges)
|
|
||||||
if len(nm.RoutesFirewallRules[j].SourceRanges) < minLen {
|
|
||||||
minLen = len(nm.RoutesFirewallRules[j].SourceRanges)
|
|
||||||
}
|
|
||||||
for k := 0; k < minLen; k++ {
|
|
||||||
if nm.RoutesFirewallRules[i].SourceRanges[k] != nm.RoutesFirewallRules[j].SourceRanges[k] {
|
|
||||||
return nm.RoutesFirewallRules[i].SourceRanges[k] < nm.RoutesFirewallRules[j].SourceRanges[k]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(nm.RoutesFirewallRules[i].SourceRanges) != len(nm.RoutesFirewallRules[j].SourceRanges) {
|
|
||||||
return len(nm.RoutesFirewallRules[i].SourceRanges) < len(nm.RoutesFirewallRules[j].SourceRanges)
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(nm.RoutesFirewallRules[i].RouteID) != string(nm.RoutesFirewallRules[j].RouteID) {
|
|
||||||
return string(nm.RoutesFirewallRules[i].RouteID) < string(nm.RoutesFirewallRules[j].RouteID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if nm.RoutesFirewallRules[i].PolicyID != nm.RoutesFirewallRules[j].PolicyID {
|
|
||||||
return nm.RoutesFirewallRules[i].PolicyID < nm.RoutesFirewallRules[j].PolicyID
|
|
||||||
}
|
|
||||||
|
|
||||||
if nm.RoutesFirewallRules[i].Port != nm.RoutesFirewallRules[j].Port {
|
|
||||||
return nm.RoutesFirewallRules[i].Port < nm.RoutesFirewallRules[j].Port
|
|
||||||
}
|
|
||||||
|
|
||||||
return nm.RoutesFirewallRules[i].Protocol < nm.RoutesFirewallRules[j].Protocol
|
|
||||||
})
|
|
||||||
|
|
||||||
if nm.DNSConfig.CustomZones != nil {
|
|
||||||
for i := range nm.DNSConfig.CustomZones {
|
|
||||||
sort.Slice(nm.DNSConfig.CustomZones[i].Records, func(a, b int) bool {
|
|
||||||
return nm.DNSConfig.CustomZones[i].Records[a].Name < nm.DNSConfig.CustomZones[i].Records[b].Name
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(nm.DNSConfig.NameServerGroups) != 0 {
|
|
||||||
sort.Slice(nm.DNSConfig.NameServerGroups, func(a, b int) bool {
|
|
||||||
return nm.DNSConfig.NameServerGroups[a].Name < nm.DNSConfig.NameServerGroups[b].Name
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func compareNetworkMaps(t *testing.T, legacy, current *NetworkMap) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
if legacy.Network.Serial != current.Network.Serial {
|
|
||||||
t.Errorf("Network Serial mismatch: legacy=%d, current=%d", legacy.Network.Serial, current.Network.Serial)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(legacy.Peers) != len(current.Peers) {
|
|
||||||
t.Errorf("Peers count mismatch: legacy=%d, current=%d", len(legacy.Peers), len(current.Peers))
|
|
||||||
}
|
|
||||||
|
|
||||||
legacyPeerIDs := make(map[string]bool)
|
|
||||||
for _, p := range legacy.Peers {
|
|
||||||
legacyPeerIDs[p.ID] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, p := range current.Peers {
|
|
||||||
if !legacyPeerIDs[p.ID] {
|
|
||||||
t.Errorf("Current NetworkMap contains peer %s not in legacy", p.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(legacy.OfflinePeers) != len(current.OfflinePeers) {
|
|
||||||
t.Errorf("OfflinePeers count mismatch: legacy=%d, current=%d", len(legacy.OfflinePeers), len(current.OfflinePeers))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(legacy.FirewallRules) != len(current.FirewallRules) {
|
|
||||||
t.Logf("FirewallRules count mismatch: legacy=%d, current=%d", len(legacy.FirewallRules), len(current.FirewallRules))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(legacy.Routes) != len(current.Routes) {
|
|
||||||
t.Logf("Routes count mismatch: legacy=%d, current=%d", len(legacy.Routes), len(current.Routes))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(legacy.RoutesFirewallRules) != len(current.RoutesFirewallRules) {
|
|
||||||
t.Logf("RoutesFirewallRules count mismatch: legacy=%d, current=%d", len(legacy.RoutesFirewallRules), len(current.RoutesFirewallRules))
|
|
||||||
}
|
|
||||||
|
|
||||||
if legacy.DNSConfig.ServiceEnable != current.DNSConfig.ServiceEnable {
|
|
||||||
t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, current=%v", legacy.DNSConfig.ServiceEnable, current.DNSConfig.ServiceEnable)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
numPeers = 100
|
|
||||||
devGroupID = "group-dev"
|
|
||||||
opsGroupID = "group-ops"
|
|
||||||
allGroupID = "group-all"
|
|
||||||
routeID = route.ID("route-main")
|
|
||||||
routeHA1ID = route.ID("route-ha-1")
|
|
||||||
routeHA2ID = route.ID("route-ha-2")
|
|
||||||
policyIDDevOps = "policy-dev-ops"
|
|
||||||
policyIDAll = "policy-all"
|
|
||||||
policyIDPosture = "policy-posture"
|
|
||||||
policyIDDrop = "policy-drop"
|
|
||||||
postureCheckID = "posture-check-ver"
|
|
||||||
networkResourceID = "res-database"
|
|
||||||
networkID = "net-database"
|
|
||||||
networkRouterID = "router-database"
|
|
||||||
nameserverGroupID = "ns-group-main"
|
|
||||||
testingPeerID = "peer-60"
|
|
||||||
expiredPeerID = "peer-98"
|
|
||||||
offlinePeerID = "peer-99"
|
|
||||||
routingPeerID = "peer-95"
|
|
||||||
testAccountID = "account-comparison-test"
|
|
||||||
)
|
|
||||||
|
|
||||||
func createTestAccount() *Account {
|
|
||||||
peers := make(map[string]*nbpeer.Peer)
|
|
||||||
devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{}
|
|
||||||
|
|
||||||
for i := range numPeers {
|
|
||||||
peerID := fmt.Sprintf("peer-%d", i)
|
|
||||||
ip := net.IP{100, 64, 0, byte(i + 1)}
|
|
||||||
wtVersion := "0.25.0"
|
|
||||||
if i%2 == 0 {
|
|
||||||
wtVersion = "0.40.0"
|
|
||||||
}
|
|
||||||
|
|
||||||
p := &nbpeer.Peer{
|
|
||||||
ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1),
|
|
||||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
|
||||||
UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"},
|
|
||||||
}
|
|
||||||
|
|
||||||
if peerID == expiredPeerID {
|
|
||||||
p.LoginExpirationEnabled = true
|
|
||||||
pastTimestamp := time.Now().Add(-2 * time.Hour)
|
|
||||||
p.LastLogin = &pastTimestamp
|
|
||||||
}
|
|
||||||
|
|
||||||
peers[peerID] = p
|
|
||||||
allGroupPeers = append(allGroupPeers, peerID)
|
|
||||||
if i < numPeers/2 {
|
|
||||||
devGroupPeers = append(devGroupPeers, peerID)
|
|
||||||
} else {
|
|
||||||
opsGroupPeers = append(opsGroupPeers, peerID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
groups := map[string]*Group{
|
|
||||||
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
|
|
||||||
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
|
|
||||||
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
|
|
||||||
}
|
|
||||||
|
|
||||||
policies := []*Policy{
|
|
||||||
{
|
|
||||||
ID: policyIDAll, Name: "Default-Allow", Enabled: true,
|
|
||||||
Rules: []*PolicyRule{{
|
|
||||||
ID: policyIDAll, Name: "Allow All", Enabled: true, Action: PolicyTrafficActionAccept,
|
|
||||||
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
|
|
||||||
Sources: []string{allGroupID}, Destinations: []string{allGroupID},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true,
|
|
||||||
Rules: []*PolicyRule{{
|
|
||||||
ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: PolicyTrafficActionAccept,
|
|
||||||
Protocol: PolicyRuleProtocolTCP, Bidirectional: false,
|
|
||||||
PortRanges: []RulePortRange{{Start: 8080, End: 8090}},
|
|
||||||
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true,
|
|
||||||
Rules: []*PolicyRule{{
|
|
||||||
ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: PolicyTrafficActionDrop,
|
|
||||||
Protocol: PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true,
|
|
||||||
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true,
|
|
||||||
SourcePostureChecks: []string{postureCheckID},
|
|
||||||
Rules: []*PolicyRule{{
|
|
||||||
ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: PolicyTrafficActionAccept,
|
|
||||||
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
|
|
||||||
Sources: []string{opsGroupID}, DestinationResource: Resource{ID: networkResourceID},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
routes := map[route.ID]*route.Route{
|
|
||||||
routeID: {
|
|
||||||
ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"),
|
|
||||||
Peer: peers["peer-75"].Key,
|
|
||||||
PeerID: "peer-75",
|
|
||||||
Description: "Route to internal resource", Enabled: true,
|
|
||||||
PeerGroups: []string{devGroupID, opsGroupID},
|
|
||||||
Groups: []string{devGroupID, opsGroupID},
|
|
||||||
AccessControlGroups: []string{devGroupID},
|
|
||||||
},
|
|
||||||
routeHA1ID: {
|
|
||||||
ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
|
|
||||||
Peer: peers["peer-80"].Key,
|
|
||||||
PeerID: "peer-80",
|
|
||||||
Description: "HA Route 1", Enabled: true, Metric: 1000,
|
|
||||||
PeerGroups: []string{allGroupID},
|
|
||||||
Groups: []string{allGroupID},
|
|
||||||
AccessControlGroups: []string{allGroupID},
|
|
||||||
},
|
|
||||||
routeHA2ID: {
|
|
||||||
ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
|
|
||||||
Peer: peers["peer-90"].Key,
|
|
||||||
PeerID: "peer-90",
|
|
||||||
Description: "HA Route 2", Enabled: true, Metric: 900,
|
|
||||||
PeerGroups: []string{devGroupID, opsGroupID},
|
|
||||||
Groups: []string{devGroupID, opsGroupID},
|
|
||||||
AccessControlGroups: []string{allGroupID},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
account := &Account{
|
|
||||||
Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes,
|
|
||||||
Network: &Network{
|
|
||||||
Identifier: "net-comparison-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
|
|
||||||
},
|
|
||||||
DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}},
|
|
||||||
NameServerGroups: map[string]*nbdns.NameServerGroup{
|
|
||||||
nameserverGroupID: {
|
|
||||||
ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID},
|
|
||||||
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
PostureChecks: []*posture.Checks{
|
|
||||||
{ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{
|
|
||||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
NetworkResources: []*resourceTypes.NetworkResource{
|
|
||||||
{ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"},
|
|
||||||
},
|
|
||||||
Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}},
|
|
||||||
NetworkRouters: []*routerTypes.NetworkRouter{
|
|
||||||
{ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID},
|
|
||||||
},
|
|
||||||
Settings: &Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, p := range account.Policies {
|
|
||||||
p.AccountID = account.Id
|
|
||||||
}
|
|
||||||
for _, r := range account.Routes {
|
|
||||||
r.AccountID = account.Id
|
|
||||||
}
|
|
||||||
|
|
||||||
return account
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkLegacyNetworkMap(b *testing.B) {
|
|
||||||
account := createTestAccount()
|
|
||||||
ctx := context.Background()
|
|
||||||
peerID := testingPeerID
|
|
||||||
validatedPeersMap := make(map[string]struct{})
|
|
||||||
for i := range numPeers {
|
|
||||||
pid := fmt.Sprintf("peer-%d", i)
|
|
||||||
if pid != offlinePeerID {
|
|
||||||
validatedPeersMap[pid] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
peersCustomZone := nbdns.CustomZone{}
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = account.GetPeerNetworkMap(
|
|
||||||
ctx,
|
|
||||||
peerID,
|
|
||||||
peersCustomZone,
|
|
||||||
nil,
|
|
||||||
validatedPeersMap,
|
|
||||||
resourcePolicies,
|
|
||||||
routers,
|
|
||||||
nil,
|
|
||||||
groupIDToUserIDs,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkComponentsNetworkMap(b *testing.B) {
|
|
||||||
account := createTestAccount()
|
|
||||||
ctx := context.Background()
|
|
||||||
peerID := testingPeerID
|
|
||||||
validatedPeersMap := make(map[string]struct{})
|
|
||||||
for i := range numPeers {
|
|
||||||
pid := fmt.Sprintf("peer-%d", i)
|
|
||||||
if pid != offlinePeerID {
|
|
||||||
validatedPeersMap[pid] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
peersCustomZone := nbdns.CustomZone{}
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
components := account.GetPeerNetworkMapComponents(
|
|
||||||
ctx,
|
|
||||||
peerID,
|
|
||||||
peersCustomZone,
|
|
||||||
nil,
|
|
||||||
validatedPeersMap,
|
|
||||||
resourcePolicies,
|
|
||||||
routers,
|
|
||||||
groupIDToUserIDs,
|
|
||||||
)
|
|
||||||
_ = CalculateNetworkMapFromComponents(ctx, components)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkComponentsCreation(b *testing.B) {
|
|
||||||
account := createTestAccount()
|
|
||||||
ctx := context.Background()
|
|
||||||
peerID := testingPeerID
|
|
||||||
validatedPeersMap := make(map[string]struct{})
|
|
||||||
for i := range numPeers {
|
|
||||||
pid := fmt.Sprintf("peer-%d", i)
|
|
||||||
if pid != offlinePeerID {
|
|
||||||
validatedPeersMap[pid] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
peersCustomZone := nbdns.CustomZone{}
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = account.GetPeerNetworkMapComponents(
|
|
||||||
ctx,
|
|
||||||
peerID,
|
|
||||||
peersCustomZone,
|
|
||||||
nil,
|
|
||||||
validatedPeersMap,
|
|
||||||
resourcePolicies,
|
|
||||||
routers,
|
|
||||||
groupIDToUserIDs,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkCalculationFromComponents(b *testing.B) {
|
|
||||||
account := createTestAccount()
|
|
||||||
ctx := context.Background()
|
|
||||||
peerID := testingPeerID
|
|
||||||
validatedPeersMap := make(map[string]struct{})
|
|
||||||
for i := range numPeers {
|
|
||||||
pid := fmt.Sprintf("peer-%d", i)
|
|
||||||
if pid != offlinePeerID {
|
|
||||||
validatedPeersMap[pid] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
peersCustomZone := nbdns.CustomZone{}
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
|
||||||
|
|
||||||
components := account.GetPeerNetworkMapComponents(
|
|
||||||
ctx,
|
|
||||||
peerID,
|
|
||||||
peersCustomZone,
|
|
||||||
nil,
|
|
||||||
validatedPeersMap,
|
|
||||||
resourcePolicies,
|
|
||||||
routers,
|
|
||||||
groupIDToUserIDs,
|
|
||||||
)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = CalculateNetworkMapFromComponents(ctx, components)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -19,8 +19,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const EnvNewNetworkMapCompacted = "NB_NETWORK_MAP_COMPACTED"
|
|
||||||
|
|
||||||
type NetworkMapComponents struct {
|
type NetworkMapComponents struct {
|
||||||
PeerID string
|
PeerID string
|
||||||
|
|
||||||
|
|||||||
787
management/server/types/networkmap_components_test.go
Normal file
787
management/server/types/networkmap_components_test.go
Normal file
@@ -0,0 +1,787 @@
|
|||||||
|
package types_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
func networkMapFromComponents(t *testing.T, account *types.Account, peerID string, validatedPeers map[string]struct{}) *types.NetworkMap {
|
||||||
|
t.Helper()
|
||||||
|
return account.GetPeerNetworkMapFromComponents(
|
||||||
|
context.Background(),
|
||||||
|
peerID,
|
||||||
|
account.GetPeersCustomZone(context.Background(), "netbird.io"),
|
||||||
|
nil,
|
||||||
|
validatedPeers,
|
||||||
|
account.GetResourcePoliciesMap(),
|
||||||
|
account.GetResourceRoutersMap(),
|
||||||
|
nil,
|
||||||
|
account.GetActiveGroupUsers(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func allPeersValidated(account *types.Account, excludePeerIDs ...string) map[string]struct{} {
|
||||||
|
excludeSet := make(map[string]struct{}, len(excludePeerIDs))
|
||||||
|
for _, id := range excludePeerIDs {
|
||||||
|
excludeSet[id] = struct{}{}
|
||||||
|
}
|
||||||
|
validated := make(map[string]struct{}, len(account.Peers))
|
||||||
|
for id := range account.Peers {
|
||||||
|
if _, excluded := excludeSet[id]; !excluded {
|
||||||
|
validated[id] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return validated
|
||||||
|
}
|
||||||
|
|
||||||
|
func peerIDs(peers []*nbpeer.Peer) []string {
|
||||||
|
ids := make([]string, len(peers))
|
||||||
|
for i, p := range peers {
|
||||||
|
ids[i] = p.ID
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_RegularPeerConnectivity(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
assert.NotNil(t, nm)
|
||||||
|
assert.Contains(t, peerIDs(nm.Peers), "peer-dst-1", "should see peer from destination group via bidirectional policy")
|
||||||
|
assert.Contains(t, peerIDs(nm.Peers), "peer-router-1", "should see router peer via resource policy")
|
||||||
|
assert.NotContains(t, peerIDs(nm.Peers), "peer-src-1", "should not see itself")
|
||||||
|
assert.Empty(t, nm.OfflinePeers, "no expired peers expected")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_IntraGroupConnectivity(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
account.Policies = append(account.Policies, &types.Policy{
|
||||||
|
ID: "policy-intra-src", Name: "Intra-source connectivity", Enabled: true, AccountID: account.Id,
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
ID: "rule-intra-src", Name: "src <-> src", Enabled: true,
|
||||||
|
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||||
|
Bidirectional: true,
|
||||||
|
Sources: []string{"group-src"}, Destinations: []string{"group-src"},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
assert.Contains(t, peerIDs(nm.Peers), "peer-src-2", "should see peer from same group with intra-group policy")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_FirewallRules(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
require.NotEmpty(t, nm.FirewallRules, "firewall rules should be generated")
|
||||||
|
|
||||||
|
var hasAcceptAll bool
|
||||||
|
for _, rule := range nm.FirewallRules {
|
||||||
|
if rule.Protocol == string(types.PolicyRuleProtocolALL) && rule.Action == string(types.PolicyTrafficActionAccept) {
|
||||||
|
hasAcceptAll = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, hasAcceptAll, "should have an accept-all firewall rule from the base policy")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_LoginExpiration(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
account.Settings.PeerLoginExpirationEnabled = true
|
||||||
|
account.Settings.PeerLoginExpiration = 1 * time.Hour
|
||||||
|
|
||||||
|
expiredTime := time.Now().Add(-2 * time.Hour)
|
||||||
|
account.Peers["peer-dst-1"].LoginExpirationEnabled = true
|
||||||
|
account.Peers["peer-dst-1"].LastLogin = &expiredTime
|
||||||
|
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
assert.Contains(t, peerIDs(nm.OfflinePeers), "peer-dst-1", "expired peer should be in OfflinePeers")
|
||||||
|
assert.NotContains(t, peerIDs(nm.Peers), "peer-dst-1", "expired peer should NOT be in active Peers")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_InvalidatedPeerExcluded(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account, "peer-dst-1")
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
assert.NotContains(t, peerIDs(nm.Peers), "peer-dst-1", "non-validated peer should be excluded")
|
||||||
|
assert.NotContains(t, peerIDs(nm.OfflinePeers), "peer-dst-1", "non-validated peer should not be in offline peers either")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_NonValidatedTargetPeer(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account, "peer-src-1")
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
assert.Empty(t, nm.Peers, "non-validated target peer should get empty network map")
|
||||||
|
assert.Empty(t, nm.FirewallRules)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_NetworkResourceRoutes_SourcePeer(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
var hasResourceRoute bool
|
||||||
|
for _, r := range nm.Routes {
|
||||||
|
if r.Network.String() == "10.200.0.1/32" {
|
||||||
|
hasResourceRoute = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, hasResourceRoute, "source peer should receive route to network resource via router")
|
||||||
|
assert.Contains(t, peerIDs(nm.Peers), "peer-router-1", "source peer should see the routing peer")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_NetworkResourceRoutes_RouterPeer(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
|
||||||
|
|
||||||
|
var hasResourceRoute bool
|
||||||
|
for _, r := range nm.Routes {
|
||||||
|
if r.Network.String() == "10.200.0.1/32" {
|
||||||
|
hasResourceRoute = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, hasResourceRoute, "router peer should receive network resource route")
|
||||||
|
assert.NotEmpty(t, nm.RoutesFirewallRules, "router peer should have route firewall rules for the resource")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_NetworkResourceRoutes_UnrelatedPeer(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-dst-1", validated)
|
||||||
|
|
||||||
|
for _, r := range nm.Routes {
|
||||||
|
assert.NotEqual(t, "10.200.0.1/32", r.Network.String(), "unrelated peer should not receive network resource route")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_NetworkResource_WithPostureCheck(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
account.PostureChecks = []*posture.Checks{
|
||||||
|
{ID: "pc-version", Name: "Version check", Checks: posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
account.Policies = append(account.Policies, &types.Policy{
|
||||||
|
ID: "policy-posture-resource", Name: "Posture resource access", Enabled: true, AccountID: account.Id,
|
||||||
|
SourcePostureChecks: []string{"pc-version"},
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
ID: "rule-posture-resource", Name: "Posture -> Resource", Enabled: true,
|
||||||
|
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||||
|
Sources: []string{"group-src"},
|
||||||
|
DestinationResource: types.Resource{ID: "resource-guarded"},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
|
||||||
|
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
|
||||||
|
ID: "resource-guarded", NetworkID: "net-guarded", AccountID: account.Id, Enabled: true,
|
||||||
|
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.1.1/32"), Address: "10.200.1.1/32",
|
||||||
|
})
|
||||||
|
account.Networks = append(account.Networks, &networkTypes.Network{
|
||||||
|
ID: "net-guarded", Name: "Guarded Net", AccountID: account.Id,
|
||||||
|
})
|
||||||
|
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
|
||||||
|
ID: "router-guarded", NetworkID: "net-guarded", Peer: "peer-router-1", Enabled: true, AccountID: account.Id,
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("peer passes posture check", func(t *testing.T) {
|
||||||
|
account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0"
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
var hasGuardedRoute bool
|
||||||
|
for _, r := range nm.Routes {
|
||||||
|
if r.Network.String() == "10.200.1.1/32" {
|
||||||
|
hasGuardedRoute = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, hasGuardedRoute, "peer passing posture check should get guarded resource route")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("peer fails posture check", func(t *testing.T) {
|
||||||
|
account.Peers["peer-src-1"].Meta.WtVersion = "0.20.0"
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
for _, r := range nm.Routes {
|
||||||
|
assert.NotEqual(t, "10.200.1.1/32", r.Network.String(), "peer failing posture check should NOT get guarded resource route")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_NetworkResource_MultiplePostureChecks(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
account.PostureChecks = []*posture.Checks{
|
||||||
|
{ID: "pc-version", Name: "Version", Checks: posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
|
||||||
|
}},
|
||||||
|
{ID: "pc-os", Name: "OS check", Checks: posture.ChecksDefinition{
|
||||||
|
OSVersionCheck: &posture.OSVersionCheck{Linux: &posture.MinKernelVersionCheck{MinKernelVersion: "5.0"}},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
account.Policies = append(account.Policies, &types.Policy{
|
||||||
|
ID: "policy-multi-posture", Name: "Multi posture", Enabled: true, AccountID: account.Id,
|
||||||
|
SourcePostureChecks: []string{"pc-version", "pc-os"},
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
ID: "rule-multi-posture", Name: "Multi posture rule", Enabled: true,
|
||||||
|
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||||
|
Sources: []string{"group-src"},
|
||||||
|
DestinationResource: types.Resource{ID: "resource-strict"},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
|
||||||
|
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
|
||||||
|
ID: "resource-strict", NetworkID: "net-strict", AccountID: account.Id, Enabled: true,
|
||||||
|
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.2.1/32"), Address: "10.200.2.1/32",
|
||||||
|
})
|
||||||
|
account.Networks = append(account.Networks, &networkTypes.Network{
|
||||||
|
ID: "net-strict", Name: "Strict Net", AccountID: account.Id,
|
||||||
|
})
|
||||||
|
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
|
||||||
|
ID: "router-strict", NetworkID: "net-strict", Peer: "peer-router-1", Enabled: true, AccountID: account.Id,
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("passes both posture checks", func(t *testing.T) {
|
||||||
|
account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0"
|
||||||
|
account.Peers["peer-src-1"].Meta.GoOS = "linux"
|
||||||
|
account.Peers["peer-src-1"].Meta.KernelVersion = "6.1.0"
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
var found bool
|
||||||
|
for _, r := range nm.Routes {
|
||||||
|
if r.Network.String() == "10.200.2.1/32" {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, found, "peer passing both checks should get resource route")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("fails version posture check", func(t *testing.T) {
|
||||||
|
account.Peers["peer-src-1"].Meta.WtVersion = "0.20.0"
|
||||||
|
account.Peers["peer-src-1"].Meta.KernelVersion = "6.1.0"
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
for _, r := range nm.Routes {
|
||||||
|
assert.NotEqual(t, "10.200.2.1/32", r.Network.String(), "peer failing version check should NOT get resource route")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("fails OS posture check", func(t *testing.T) {
|
||||||
|
account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0"
|
||||||
|
account.Peers["peer-src-1"].Meta.KernelVersion = "4.0.0"
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
for _, r := range nm.Routes {
|
||||||
|
assert.NotEqual(t, "10.200.2.1/32", r.Network.String(), "peer failing OS check should NOT get resource route")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_RouterPeerFirewallRules(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
|
||||||
|
|
||||||
|
var resourceFWRules []*types.RouteFirewallRule
|
||||||
|
for _, rule := range nm.RoutesFirewallRules {
|
||||||
|
if rule.Destination == "10.200.0.1/32" {
|
||||||
|
resourceFWRules = append(resourceFWRules, rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.NotEmpty(t, resourceFWRules, "router should have firewall rules for the network resource")
|
||||||
|
|
||||||
|
var hasSourcePeerIP bool
|
||||||
|
for _, rule := range resourceFWRules {
|
||||||
|
for _, sr := range rule.SourceRanges {
|
||||||
|
if sr == account.Peers["peer-src-1"].IP.String()+"/32" || sr == account.Peers["peer-src-2"].IP.String()+"/32" {
|
||||||
|
hasSourcePeerIP = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, hasSourcePeerIP, "resource firewall rules should include source peer IPs")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_DNSManagement(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
t.Run("peer in DNS-enabled group", func(t *testing.T) {
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
assert.True(t, nm.DNSConfig.ServiceEnable, "peer in non-disabled group should have DNS enabled")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("peer in DNS-disabled group", func(t *testing.T) {
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-dst-1", validated)
|
||||||
|
assert.False(t, nm.DNSConfig.ServiceEnable, "peer in DNS-disabled group should have DNS disabled")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_NameServerGroups(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
assert.True(t, nm.DNSConfig.ServiceEnable)
|
||||||
|
|
||||||
|
var hasNSGroup bool
|
||||||
|
for _, ns := range nm.DNSConfig.NameServerGroups {
|
||||||
|
if ns.ID == "ns-main" {
|
||||||
|
hasNSGroup = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, hasNSGroup, "peer in NS group should receive nameserver configuration")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_RoutesWithHADeduplication(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
account.Routes["route-ha-1"] = &route.Route{
|
||||||
|
ID: "route-ha-1", Network: netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
|
Peer: account.Peers["peer-dst-1"].Key, PeerID: "peer-dst-1",
|
||||||
|
Enabled: true, Metric: 100, AccountID: account.Id,
|
||||||
|
Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-dst"},
|
||||||
|
}
|
||||||
|
account.Routes["route-ha-2"] = &route.Route{
|
||||||
|
ID: "route-ha-2", Network: netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
|
Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1",
|
||||||
|
Enabled: true, Metric: 200, AccountID: account.Id,
|
||||||
|
Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-src"},
|
||||||
|
}
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
haCount := 0
|
||||||
|
for _, r := range nm.Routes {
|
||||||
|
if r.Network.String() == "172.16.0.0/16" {
|
||||||
|
haCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, 1, haCount, "peer should only receive one route from HA group (not both, since it's a member of one)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_RoutesFirewallRulesForAccessControl(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
account.Routes["route-acl"] = &route.Route{
|
||||||
|
ID: "route-acl", Network: netip.MustParsePrefix("192.168.100.0/24"),
|
||||||
|
Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1",
|
||||||
|
Enabled: true, Metric: 100, AccountID: account.Id,
|
||||||
|
Groups: []string{"group-dst"},
|
||||||
|
PeerGroups: []string{"group-src"},
|
||||||
|
AccessControlGroups: []string{"group-dst"},
|
||||||
|
}
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
var hasFWRule bool
|
||||||
|
for _, rule := range nm.RoutesFirewallRules {
|
||||||
|
if rule.Destination == "192.168.100.0/24" {
|
||||||
|
hasFWRule = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, hasFWRule, "routing peer should have firewall rules for route with access control groups")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_RoutesDefaultPermit(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
account.Routes["route-open"] = &route.Route{
|
||||||
|
ID: "route-open", Network: netip.MustParsePrefix("10.99.0.0/16"),
|
||||||
|
Peer: account.Peers["peer-src-1"].Key, PeerID: "peer-src-1",
|
||||||
|
Enabled: true, Metric: 100, AccountID: account.Id,
|
||||||
|
Groups: []string{"group-src"},
|
||||||
|
PeerGroups: []string{"group-src"},
|
||||||
|
}
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
var hasFWRule bool
|
||||||
|
for _, rule := range nm.RoutesFirewallRules {
|
||||||
|
if rule.Destination == "10.99.0.0/16" {
|
||||||
|
hasFWRule = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, hasFWRule, "route without access control groups should have default permit firewall rules")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_SSHAuthorizedUsers(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
account.Peers["peer-dst-1"].SSHEnabled = true
|
||||||
|
|
||||||
|
account.Policies = append(account.Policies, &types.Policy{
|
||||||
|
ID: "policy-ssh", Name: "SSH Access", Enabled: true, AccountID: account.Id,
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
ID: "rule-ssh", Name: "SSH to dst", Enabled: true,
|
||||||
|
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||||
|
Bidirectional: true,
|
||||||
|
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-dst-1", validated)
|
||||||
|
assert.True(t, nm.EnableSSH, "SSH-enabled peer with matching policy should have EnableSSH")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_DisabledPolicyIgnored(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
for _, p := range account.Policies {
|
||||||
|
p.Enabled = false
|
||||||
|
}
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
assert.Empty(t, nm.Peers, "with all policies disabled, peer should see no other peers")
|
||||||
|
assert.Empty(t, nm.FirewallRules)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_DisabledRouteIgnored(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
for _, r := range account.Routes {
|
||||||
|
r.Enabled = false
|
||||||
|
}
|
||||||
|
for _, r := range account.NetworkResources {
|
||||||
|
r.Enabled = false
|
||||||
|
}
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
assert.Empty(t, nm.Routes, "disabled routes should not appear in network map")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_DisabledNetworkResourceIgnored(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
for _, r := range account.NetworkResources {
|
||||||
|
r.Enabled = false
|
||||||
|
}
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
|
||||||
|
|
||||||
|
for _, r := range nm.Routes {
|
||||||
|
assert.NotEqual(t, "10.200.0.1/32", r.Network.String(), "disabled resource should not generate routes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_BidirectionalPolicy(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
nmSrc := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
nmDst := networkMapFromComponents(t, account, "peer-dst-1", validated)
|
||||||
|
|
||||||
|
assert.Contains(t, peerIDs(nmSrc.Peers), "peer-dst-1", "src should see dst via bidirectional policy")
|
||||||
|
assert.Contains(t, peerIDs(nmDst.Peers), "peer-src-1", "dst should see src via bidirectional policy")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_DropPolicy(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
account.Policies = append(account.Policies, &types.Policy{
|
||||||
|
ID: "policy-drop", Name: "Drop traffic", Enabled: true, AccountID: account.Id,
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
ID: "rule-drop", Name: "Drop src->dst", Enabled: true,
|
||||||
|
Action: types.PolicyTrafficActionDrop, Protocol: types.PolicyRuleProtocolTCP,
|
||||||
|
Ports: []string{"5432"},
|
||||||
|
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
var hasDropRule bool
|
||||||
|
for _, rule := range nm.FirewallRules {
|
||||||
|
if rule.Action == string(types.PolicyTrafficActionDrop) && rule.Port == "5432" {
|
||||||
|
hasDropRule = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, hasDropRule, "drop policy should generate drop firewall rule")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_PortRangePolicy(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
account.Peers["peer-src-1"].Meta.WtVersion = "0.50.0"
|
||||||
|
|
||||||
|
account.Policies = append(account.Policies, &types.Policy{
|
||||||
|
ID: "policy-range", Name: "Port range", Enabled: true, AccountID: account.Id,
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
ID: "rule-range", Name: "Range rule", Enabled: true,
|
||||||
|
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolTCP,
|
||||||
|
PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}},
|
||||||
|
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
var hasRangeRule bool
|
||||||
|
for _, rule := range nm.FirewallRules {
|
||||||
|
if rule.PortRange.Start == 8080 && rule.PortRange.End == 8090 {
|
||||||
|
hasRangeRule = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, hasRangeRule, "port range policy should generate corresponding firewall rule")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_MultipleNetworkResources(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
|
||||||
|
ID: "resource-2", NetworkID: "net-1", AccountID: account.Id, Enabled: true,
|
||||||
|
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.0.2/32"), Address: "10.200.0.2/32",
|
||||||
|
})
|
||||||
|
account.Groups["group-res2"] = &types.Group{ID: "group-res2", Name: "Resource 2 Group", Peers: []string{"peer-src-1", "peer-src-2"},
|
||||||
|
Resources: []types.Resource{{ID: "resource-2"}},
|
||||||
|
}
|
||||||
|
account.Policies = append(account.Policies, &types.Policy{
|
||||||
|
ID: "policy-res2", Name: "Resource 2 Policy", Enabled: true, AccountID: account.Id,
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
ID: "rule-res2", Name: "Access Resource 2", Enabled: true,
|
||||||
|
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||||
|
Sources: []string{"group-src"},
|
||||||
|
DestinationResource: types.Resource{ID: "resource-2"},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
|
||||||
|
|
||||||
|
resourceRouteCount := 0
|
||||||
|
for _, r := range nm.Routes {
|
||||||
|
if r.Network.String() == "10.200.0.1/32" || r.Network.String() == "10.200.0.2/32" {
|
||||||
|
resourceRouteCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, 2, resourceRouteCount, "router should have routes for both network resources")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_DomainNetworkResource(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
|
||||||
|
ID: "resource-domain", NetworkID: "net-1", AccountID: account.Id, Enabled: true,
|
||||||
|
Type: resourceTypes.Domain, Domain: "api.example.com", Address: "api.example.com",
|
||||||
|
})
|
||||||
|
account.Groups["group-res-domain"] = &types.Group{
|
||||||
|
ID: "group-res-domain", Name: "Domain Resource Group",
|
||||||
|
Resources: []types.Resource{{ID: "resource-domain"}},
|
||||||
|
}
|
||||||
|
account.Policies = append(account.Policies, &types.Policy{
|
||||||
|
ID: "policy-domain", Name: "Domain resource policy", Enabled: true, AccountID: account.Id,
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
ID: "rule-domain", Name: "Access domain resource", Enabled: true,
|
||||||
|
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||||
|
Sources: []string{"group-src"},
|
||||||
|
DestinationResource: types.Resource{ID: "resource-domain"},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
|
||||||
|
|
||||||
|
var hasDomainRoute bool
|
||||||
|
for _, r := range nm.Routes {
|
||||||
|
if r.NetworkType == route.DomainNetwork && len(r.Domains) > 0 && r.Domains[0].SafeString() == "api.example.com" {
|
||||||
|
hasDomainRoute = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, hasDomainRoute, "source peer should receive domain route for domain network resource")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_NetworkEmpty(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "nonexistent-peer", validated)
|
||||||
|
|
||||||
|
assert.NotNil(t, nm)
|
||||||
|
assert.Empty(t, nm.Peers)
|
||||||
|
assert.Empty(t, nm.FirewallRules)
|
||||||
|
assert.NotNil(t, nm.Network)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_RouterExcludesOtherNetworkRoutes(t *testing.T) {
|
||||||
|
account := createComponentTestAccount()
|
||||||
|
validated := allPeersValidated(account)
|
||||||
|
|
||||||
|
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
|
||||||
|
ID: "resource-other", NetworkID: "net-other", AccountID: account.Id, Enabled: true,
|
||||||
|
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.99.1/32"), Address: "10.200.99.1/32",
|
||||||
|
})
|
||||||
|
account.Networks = append(account.Networks, &networkTypes.Network{
|
||||||
|
ID: "net-other", Name: "Other Net", AccountID: account.Id,
|
||||||
|
})
|
||||||
|
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
|
||||||
|
ID: "router-other", NetworkID: "net-other", Peer: "peer-dst-1", Enabled: true, AccountID: account.Id,
|
||||||
|
})
|
||||||
|
account.Groups["group-res-other"] = &types.Group{ID: "group-res-other", Name: "Other resource group",
|
||||||
|
Resources: []types.Resource{{ID: "resource-other"}},
|
||||||
|
}
|
||||||
|
account.Policies = append(account.Policies, &types.Policy{
|
||||||
|
ID: "policy-other-resource", Name: "Other resource policy", Enabled: true, AccountID: account.Id,
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
ID: "rule-other", Name: "Other resource access", Enabled: true,
|
||||||
|
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||||
|
Sources: []string{"group-src"},
|
||||||
|
DestinationResource: types.Resource{ID: "resource-other"},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
|
||||||
|
nm := networkMapFromComponents(t, account, "peer-router-1", validated)
|
||||||
|
|
||||||
|
for _, r := range nm.Routes {
|
||||||
|
assert.NotEqual(t, "10.200.99.1/32", r.Network.String(), "router-1 should NOT get routes for other network's resources")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createComponentTestAccount() *types.Account {
|
||||||
|
peers := map[string]*nbpeer.Peer{
|
||||||
|
"peer-src-1": {
|
||||||
|
ID: "peer-src-1", IP: net.IP{100, 64, 0, 1}, Key: "key-src-1", DNSLabel: "src1",
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1",
|
||||||
|
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
|
||||||
|
},
|
||||||
|
"peer-src-2": {
|
||||||
|
ID: "peer-src-2", IP: net.IP{100, 64, 0, 2}, Key: "key-src-2", DNSLabel: "src2",
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1",
|
||||||
|
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
|
||||||
|
},
|
||||||
|
"peer-dst-1": {
|
||||||
|
ID: "peer-dst-1", IP: net.IP{100, 64, 0, 3}, Key: "key-dst-1", DNSLabel: "dst1",
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-2",
|
||||||
|
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
|
||||||
|
},
|
||||||
|
"peer-router-1": {
|
||||||
|
ID: "peer-router-1", IP: net.IP{100, 64, 0, 10}, Key: "key-router-1", DNSLabel: "router1",
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1",
|
||||||
|
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
groups := map[string]*types.Group{
|
||||||
|
"group-src": {ID: "group-src", Name: "Sources", Peers: []string{"peer-src-1", "peer-src-2"}},
|
||||||
|
"group-dst": {ID: "group-dst", Name: "Destinations", Peers: []string{"peer-dst-1"}},
|
||||||
|
"group-all": {ID: "group-all", Name: "All", Peers: []string{"peer-src-1", "peer-src-2", "peer-dst-1", "peer-router-1"}},
|
||||||
|
"group-res": {
|
||||||
|
ID: "group-res", Name: "Resource Group",
|
||||||
|
Resources: []types.Resource{{ID: "resource-1"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
policies := []*types.Policy{
|
||||||
|
{
|
||||||
|
ID: "policy-base", Name: "Base connectivity", Enabled: true,
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
ID: "rule-base", Name: "Allow src <-> dst", Enabled: true,
|
||||||
|
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||||
|
Bidirectional: true,
|
||||||
|
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "policy-resource", Name: "Network resource access", Enabled: true,
|
||||||
|
Rules: []*types.PolicyRule{{
|
||||||
|
ID: "rule-resource", Name: "Source -> Resource", Enabled: true,
|
||||||
|
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
|
||||||
|
Sources: []string{"group-src"},
|
||||||
|
DestinationResource: types.Resource{ID: "resource-1"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := map[route.ID]*route.Route{
|
||||||
|
"route-main": {
|
||||||
|
ID: "route-main", Network: netip.MustParsePrefix("192.168.10.0/24"),
|
||||||
|
Peer: peers["peer-dst-1"].Key, PeerID: "peer-dst-1",
|
||||||
|
Enabled: true, Metric: 100,
|
||||||
|
Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-dst"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
users := map[string]*types.User{
|
||||||
|
"user-1": {Id: "user-1", Role: types.UserRoleAdmin, IsServiceUser: false, AutoGroups: []string{"group-all"}},
|
||||||
|
"user-2": {Id: "user-2", Role: types.UserRoleUser, IsServiceUser: false, AutoGroups: []string{"group-all"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &types.Account{
|
||||||
|
Id: "account-components-test", Peers: peers, Groups: groups, Policies: policies, Routes: routes,
|
||||||
|
Users: users,
|
||||||
|
Network: &types.Network{
|
||||||
|
Identifier: "net-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
|
||||||
|
},
|
||||||
|
DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{"group-dst"}},
|
||||||
|
NameServerGroups: map[string]*nbdns.NameServerGroup{
|
||||||
|
"ns-main": {
|
||||||
|
ID: "ns-main", Name: "Main NS", Enabled: true, Groups: []string{"group-src"},
|
||||||
|
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
PostureChecks: []*posture.Checks{},
|
||||||
|
NetworkResources: []*resourceTypes.NetworkResource{
|
||||||
|
{
|
||||||
|
ID: "resource-1", NetworkID: "net-1", AccountID: "account-components-test", Enabled: true,
|
||||||
|
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.0.1/32"), Address: "10.200.0.1/32",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Networks: []*networkTypes.Network{
|
||||||
|
{ID: "net-1", Name: "Resource Net", AccountID: "account-components-test"},
|
||||||
|
},
|
||||||
|
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||||
|
{ID: "router-1", NetworkID: "net-1", Peer: "peer-router-1", Enabled: true, AccountID: "account-components-test"},
|
||||||
|
},
|
||||||
|
Settings: &types.Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: 24 * time.Hour},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range account.Policies {
|
||||||
|
p.AccountID = account.Id
|
||||||
|
}
|
||||||
|
for _, r := range account.Routes {
|
||||||
|
r.AccountID = account.Id
|
||||||
|
}
|
||||||
|
|
||||||
|
return account
|
||||||
|
}
|
||||||
@@ -1,967 +0,0 @@
|
|||||||
package types_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"slices"
|
|
||||||
"sort"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
numPeers = 100
|
|
||||||
devGroupID = "group-dev"
|
|
||||||
opsGroupID = "group-ops"
|
|
||||||
allGroupID = "group-all"
|
|
||||||
sshUsersGroupID = "group-ssh-users"
|
|
||||||
routeID = route.ID("route-main")
|
|
||||||
routeHA1ID = route.ID("route-ha-1")
|
|
||||||
routeHA2ID = route.ID("route-ha-2")
|
|
||||||
policyIDDevOps = "policy-dev-ops"
|
|
||||||
policyIDAll = "policy-all"
|
|
||||||
policyIDPosture = "policy-posture"
|
|
||||||
policyIDDrop = "policy-drop"
|
|
||||||
policyIDSSH = "policy-ssh"
|
|
||||||
postureCheckID = "posture-check-ver"
|
|
||||||
networkResourceID = "res-database"
|
|
||||||
networkID = "net-database"
|
|
||||||
networkRouterID = "router-database"
|
|
||||||
nameserverGroupID = "ns-group-main"
|
|
||||||
testingPeerID = "peer-60" // A peer from the "dev" group, should receive the most detailed map.
|
|
||||||
expiredPeerID = "peer-98" // This peer will be online but with an expired session.
|
|
||||||
offlinePeerID = "peer-99" // This peer will be completely offline.
|
|
||||||
routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network.
|
|
||||||
testAccountID = "account-golden-test"
|
|
||||||
userAdminID = "user-admin"
|
|
||||||
userDevID = "user-dev"
|
|
||||||
userOpsID = "user-ops"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetPeerNetworkMap_Golden(t *testing.T) {
|
|
||||||
account := createTestAccountWithEntities()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
validatedPeersMap := make(map[string]struct{})
|
|
||||||
for i := range numPeers {
|
|
||||||
peerID := fmt.Sprintf("peer-%d", i)
|
|
||||||
if peerID == offlinePeerID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
validatedPeersMap[peerID] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
|
||||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
|
||||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
|
||||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
|
||||||
|
|
||||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
|
||||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
|
||||||
normalizeAndSortNetworkMap(newNetworkMap)
|
|
||||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
|
||||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
|
||||||
|
|
||||||
if string(legacyJSON) != string(newJSON) {
|
|
||||||
legacyFilePath := filepath.Join("testdata", "networkmap_golden.json")
|
|
||||||
newFilePath := filepath.Join("testdata", "networkmap_golden_new.json")
|
|
||||||
|
|
||||||
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Logf("Saved legacy network map to %s", legacyFilePath)
|
|
||||||
|
|
||||||
err = os.WriteFile(newFilePath, newJSON, 0644)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Logf("Saved new network map to %s", newFilePath)
|
|
||||||
|
|
||||||
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps from legacy and new builder do not match")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkGetPeerNetworkMap(b *testing.B) {
|
|
||||||
account := createTestAccountWithEntities()
|
|
||||||
ctx := context.Background()
|
|
||||||
validatedPeersMap := make(map[string]struct{})
|
|
||||||
var peerIDs []string
|
|
||||||
for i := range numPeers {
|
|
||||||
peerID := fmt.Sprintf("peer-%d", i)
|
|
||||||
validatedPeersMap[peerID] = struct{}{}
|
|
||||||
peerIDs = append(peerIDs, peerID)
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
b.Run("old builder", func(b *testing.B) {
|
|
||||||
for range b.N {
|
|
||||||
for _, peerID := range peerIDs {
|
|
||||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
b.ResetTimer()
|
|
||||||
b.Run("new builder", func(b *testing.B) {
|
|
||||||
for range b.N {
|
|
||||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
|
||||||
for _, peerID := range peerIDs {
|
|
||||||
_ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
|
|
||||||
account := createTestAccountWithEntities()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
validatedPeersMap := make(map[string]struct{})
|
|
||||||
for i := range numPeers {
|
|
||||||
peerID := fmt.Sprintf("peer-%d", i)
|
|
||||||
if peerID == offlinePeerID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
validatedPeersMap[peerID] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
|
||||||
|
|
||||||
newPeerID := "peer-new-101"
|
|
||||||
newPeerIP := net.IP{100, 64, 1, 1}
|
|
||||||
newPeer := &nbpeer.Peer{
|
|
||||||
ID: newPeerID,
|
|
||||||
IP: newPeerIP,
|
|
||||||
Key: fmt.Sprintf("key-%s", newPeerID),
|
|
||||||
DNSLabel: "peernew101",
|
|
||||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
|
||||||
UserID: "user-admin",
|
|
||||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
|
|
||||||
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Peers[newPeerID] = newPeer
|
|
||||||
|
|
||||||
if devGroup, exists := account.Groups[devGroupID]; exists {
|
|
||||||
devGroup.Peers = append(devGroup.Peers, newPeerID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if allGroup, exists := account.Groups[allGroupID]; exists {
|
|
||||||
allGroup.Peers = append(allGroup.Peers, newPeerID)
|
|
||||||
}
|
|
||||||
|
|
||||||
validatedPeersMap[newPeerID] = struct{}{}
|
|
||||||
|
|
||||||
if account.Network != nil {
|
|
||||||
account.Network.Serial++
|
|
||||||
}
|
|
||||||
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
|
||||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
|
||||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
|
||||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
|
||||||
|
|
||||||
err = builder.OnPeerAddedIncremental(account, newPeerID)
|
|
||||||
require.NoError(t, err, "error adding peer to cache")
|
|
||||||
|
|
||||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
|
||||||
normalizeAndSortNetworkMap(newNetworkMap)
|
|
||||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
|
||||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
|
||||||
|
|
||||||
if string(legacyJSON) != string(newJSON) {
|
|
||||||
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json")
|
|
||||||
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json")
|
|
||||||
|
|
||||||
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Logf("Saved legacy network map to %s", legacyFilePath)
|
|
||||||
|
|
||||||
err = os.WriteFile(newFilePath, newJSON, 0644)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Logf("Saved new network map to %s", newFilePath)
|
|
||||||
|
|
||||||
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new peer from legacy and new builder do not match")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
|
||||||
account := createTestAccountWithEntities()
|
|
||||||
ctx := context.Background()
|
|
||||||
validatedPeersMap := make(map[string]struct{})
|
|
||||||
var peerIDs []string
|
|
||||||
for i := range numPeers {
|
|
||||||
peerID := fmt.Sprintf("peer-%d", i)
|
|
||||||
validatedPeersMap[peerID] = struct{}{}
|
|
||||||
peerIDs = append(peerIDs, peerID)
|
|
||||||
}
|
|
||||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
|
||||||
newPeerID := "peer-new-101"
|
|
||||||
newPeer := &nbpeer.Peer{
|
|
||||||
ID: newPeerID,
|
|
||||||
IP: net.IP{100, 64, 1, 1},
|
|
||||||
Key: fmt.Sprintf("key-%s", newPeerID),
|
|
||||||
DNSLabel: "peernew101",
|
|
||||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
|
||||||
UserID: "user-admin",
|
|
||||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Peers[newPeerID] = newPeer
|
|
||||||
account.Groups[devGroupID].Peers = append(account.Groups[devGroupID].Peers, newPeerID)
|
|
||||||
account.Groups[allGroupID].Peers = append(account.Groups[allGroupID].Peers, newPeerID)
|
|
||||||
validatedPeersMap[newPeerID] = struct{}{}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
b.Run("old builder after add", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
for _, testingPeerID := range peerIDs {
|
|
||||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
b.Run("new builder after add", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = builder.OnPeerAddedIncremental(account, newPeerID)
|
|
||||||
for _, testingPeerID := range peerIDs {
|
|
||||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
|
|
||||||
account := createTestAccountWithEntities()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
validatedPeersMap := make(map[string]struct{})
|
|
||||||
for i := range numPeers {
|
|
||||||
peerID := fmt.Sprintf("peer-%d", i)
|
|
||||||
if peerID == offlinePeerID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
validatedPeersMap[peerID] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
|
||||||
|
|
||||||
newRouterID := "peer-new-router-102"
|
|
||||||
newRouterIP := net.IP{100, 64, 1, 2}
|
|
||||||
newRouter := &nbpeer.Peer{
|
|
||||||
ID: newRouterID,
|
|
||||||
IP: newRouterIP,
|
|
||||||
Key: fmt.Sprintf("key-%s", newRouterID),
|
|
||||||
DNSLabel: "newrouter102",
|
|
||||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
|
||||||
UserID: "user-admin",
|
|
||||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
|
|
||||||
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Peers[newRouterID] = newRouter
|
|
||||||
|
|
||||||
if opsGroup, exists := account.Groups[opsGroupID]; exists {
|
|
||||||
opsGroup.Peers = append(opsGroup.Peers, newRouterID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if allGroup, exists := account.Groups[allGroupID]; exists {
|
|
||||||
allGroup.Peers = append(allGroup.Peers, newRouterID)
|
|
||||||
}
|
|
||||||
|
|
||||||
newRoute := &route.Route{
|
|
||||||
ID: route.ID("route-new-router"),
|
|
||||||
Network: netip.MustParsePrefix("172.16.0.0/24"),
|
|
||||||
Peer: newRouter.Key,
|
|
||||||
PeerID: newRouterID,
|
|
||||||
Description: "Route from new router",
|
|
||||||
Enabled: true,
|
|
||||||
PeerGroups: []string{opsGroupID},
|
|
||||||
Groups: []string{devGroupID, opsGroupID},
|
|
||||||
AccessControlGroups: []string{devGroupID},
|
|
||||||
AccountID: account.Id,
|
|
||||||
}
|
|
||||||
account.Routes[newRoute.ID] = newRoute
|
|
||||||
|
|
||||||
validatedPeersMap[newRouterID] = struct{}{}
|
|
||||||
|
|
||||||
if account.Network != nil {
|
|
||||||
account.Network.Serial++
|
|
||||||
}
|
|
||||||
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
|
||||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
|
||||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
|
||||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
|
||||||
|
|
||||||
err = builder.OnPeerAddedIncremental(account, newRouterID)
|
|
||||||
require.NoError(t, err, "error adding router to cache")
|
|
||||||
|
|
||||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
|
||||||
normalizeAndSortNetworkMap(newNetworkMap)
|
|
||||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
|
||||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
|
||||||
|
|
||||||
if string(legacyJSON) != string(newJSON) {
|
|
||||||
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json")
|
|
||||||
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")
|
|
||||||
|
|
||||||
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Logf("Saved legacy network map to %s", legacyFilePath)
|
|
||||||
|
|
||||||
err = os.WriteFile(newFilePath, newJSON, 0644)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Logf("Saved new network map to %s", newFilePath)
|
|
||||||
|
|
||||||
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new router from legacy and new builder do not match")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
|
||||||
account := createTestAccountWithEntities()
|
|
||||||
ctx := context.Background()
|
|
||||||
validatedPeersMap := make(map[string]struct{})
|
|
||||||
var peerIDs []string
|
|
||||||
for i := range numPeers {
|
|
||||||
peerID := fmt.Sprintf("peer-%d", i)
|
|
||||||
validatedPeersMap[peerID] = struct{}{}
|
|
||||||
peerIDs = append(peerIDs, peerID)
|
|
||||||
}
|
|
||||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
|
||||||
newRouterID := "peer-new-router-102"
|
|
||||||
newRouterIP := net.IP{100, 64, 1, 2}
|
|
||||||
newRouter := &nbpeer.Peer{
|
|
||||||
ID: newRouterID,
|
|
||||||
IP: newRouterIP,
|
|
||||||
Key: fmt.Sprintf("key-%s", newRouterID),
|
|
||||||
DNSLabel: "newrouter102",
|
|
||||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
|
||||||
UserID: "user-admin",
|
|
||||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
|
|
||||||
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Peers[newRouterID] = newRouter
|
|
||||||
|
|
||||||
if opsGroup, exists := account.Groups[opsGroupID]; exists {
|
|
||||||
opsGroup.Peers = append(opsGroup.Peers, newRouterID)
|
|
||||||
}
|
|
||||||
if allGroup, exists := account.Groups[allGroupID]; exists {
|
|
||||||
allGroup.Peers = append(allGroup.Peers, newRouterID)
|
|
||||||
}
|
|
||||||
|
|
||||||
newRoute := &route.Route{
|
|
||||||
ID: route.ID("route-new-router"),
|
|
||||||
Network: netip.MustParsePrefix("172.16.0.0/24"),
|
|
||||||
Peer: newRouter.Key,
|
|
||||||
PeerID: newRouterID,
|
|
||||||
Description: "Route from new router",
|
|
||||||
Enabled: true,
|
|
||||||
PeerGroups: []string{opsGroupID},
|
|
||||||
Groups: []string{devGroupID, opsGroupID},
|
|
||||||
AccessControlGroups: []string{devGroupID},
|
|
||||||
AccountID: account.Id,
|
|
||||||
}
|
|
||||||
account.Routes[newRoute.ID] = newRoute
|
|
||||||
|
|
||||||
validatedPeersMap[newRouterID] = struct{}{}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
b.Run("old builder after add", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
for _, testingPeerID := range peerIDs {
|
|
||||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
b.Run("new builder after add", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = builder.OnPeerAddedIncremental(account, newRouterID)
|
|
||||||
for _, testingPeerID := range peerIDs {
|
|
||||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
|
|
||||||
account := createTestAccountWithEntities()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
validatedPeersMap := make(map[string]struct{})
|
|
||||||
for i := range numPeers {
|
|
||||||
peerID := fmt.Sprintf("peer-%d", i)
|
|
||||||
if peerID == offlinePeerID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
validatedPeersMap[peerID] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
|
||||||
|
|
||||||
deletedPeerID := "peer-25"
|
|
||||||
|
|
||||||
delete(account.Peers, deletedPeerID)
|
|
||||||
|
|
||||||
if devGroup, exists := account.Groups[devGroupID]; exists {
|
|
||||||
devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool {
|
|
||||||
return id == deletedPeerID
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if allGroup, exists := account.Groups[allGroupID]; exists {
|
|
||||||
allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool {
|
|
||||||
return id == deletedPeerID
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(validatedPeersMap, deletedPeerID)
|
|
||||||
|
|
||||||
if account.Network != nil {
|
|
||||||
account.Network.Serial++
|
|
||||||
}
|
|
||||||
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
|
||||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
|
||||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
|
||||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
|
||||||
|
|
||||||
err = builder.OnPeerDeleted(account, deletedPeerID)
|
|
||||||
require.NoError(t, err, "error deleting peer from cache")
|
|
||||||
|
|
||||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
|
||||||
normalizeAndSortNetworkMap(newNetworkMap)
|
|
||||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
|
||||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
|
||||||
|
|
||||||
if string(legacyJSON) != string(newJSON) {
|
|
||||||
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json")
|
|
||||||
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json")
|
|
||||||
|
|
||||||
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Logf("Saved legacy network map to %s", legacyFilePath)
|
|
||||||
|
|
||||||
err = os.WriteFile(newFilePath, newJSON, 0644)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Logf("Saved new network map to %s", newFilePath)
|
|
||||||
|
|
||||||
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted peer from legacy and new builder do not match")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
|
||||||
account := createTestAccountWithEntities()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
validatedPeersMap := make(map[string]struct{})
|
|
||||||
for i := range numPeers {
|
|
||||||
peerID := fmt.Sprintf("peer-%d", i)
|
|
||||||
if peerID == offlinePeerID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
validatedPeersMap[peerID] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
|
||||||
|
|
||||||
deletedRouterID := "peer-75"
|
|
||||||
|
|
||||||
var affectedRoute *route.Route
|
|
||||||
for _, r := range account.Routes {
|
|
||||||
if r.PeerID == deletedRouterID {
|
|
||||||
affectedRoute = r
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
require.NotNil(t, affectedRoute, "Router peer should have a route")
|
|
||||||
|
|
||||||
for _, group := range account.Groups {
|
|
||||||
group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool {
|
|
||||||
return id == deletedRouterID
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
for routeID, r := range account.Routes {
|
|
||||||
if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID {
|
|
||||||
delete(account.Routes, routeID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete(account.Peers, deletedRouterID)
|
|
||||||
delete(validatedPeersMap, deletedRouterID)
|
|
||||||
|
|
||||||
if account.Network != nil {
|
|
||||||
account.Network.Serial++
|
|
||||||
}
|
|
||||||
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
|
|
||||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
|
||||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
|
||||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
|
||||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
|
||||||
|
|
||||||
err = builder.OnPeerDeleted(account, deletedRouterID)
|
|
||||||
require.NoError(t, err, "error deleting routing peer from cache")
|
|
||||||
|
|
||||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
|
||||||
normalizeAndSortNetworkMap(newNetworkMap)
|
|
||||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
|
||||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
|
||||||
|
|
||||||
if string(legacyJSON) != string(newJSON) {
|
|
||||||
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json")
|
|
||||||
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json")
|
|
||||||
|
|
||||||
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Logf("Saved legacy network map to %s", legacyFilePath)
|
|
||||||
|
|
||||||
err = os.WriteFile(newFilePath, newJSON, 0644)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Logf("Saved new network map to %s", newFilePath)
|
|
||||||
|
|
||||||
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted router from legacy and new builder do not match")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
|
||||||
account := createTestAccountWithEntities()
|
|
||||||
ctx := context.Background()
|
|
||||||
validatedPeersMap := make(map[string]struct{})
|
|
||||||
var peerIDs []string
|
|
||||||
for i := range numPeers {
|
|
||||||
peerID := fmt.Sprintf("peer-%d", i)
|
|
||||||
validatedPeersMap[peerID] = struct{}{}
|
|
||||||
peerIDs = append(peerIDs, peerID)
|
|
||||||
}
|
|
||||||
|
|
||||||
deletedPeerID := "peer-25"
|
|
||||||
|
|
||||||
delete(account.Peers, deletedPeerID)
|
|
||||||
account.Groups[devGroupID].Peers = slices.DeleteFunc(account.Groups[devGroupID].Peers, func(id string) bool {
|
|
||||||
return id == deletedPeerID
|
|
||||||
})
|
|
||||||
account.Groups[allGroupID].Peers = slices.DeleteFunc(account.Groups[allGroupID].Peers, func(id string) bool {
|
|
||||||
return id == deletedPeerID
|
|
||||||
})
|
|
||||||
delete(validatedPeersMap, deletedPeerID)
|
|
||||||
|
|
||||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
b.Run("old builder after delete", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
for _, testingPeerID := range peerIDs {
|
|
||||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
b.Run("new builder after delete", func(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = builder.OnPeerDeleted(account, deletedPeerID)
|
|
||||||
for _, testingPeerID := range peerIDs {
|
|
||||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) {
|
|
||||||
for _, peer := range networkMap.Peers {
|
|
||||||
if peer.Status != nil {
|
|
||||||
peer.Status.LastSeen = time.Time{}
|
|
||||||
}
|
|
||||||
peer.LastLogin = &time.Time{}
|
|
||||||
}
|
|
||||||
for _, peer := range networkMap.OfflinePeers {
|
|
||||||
if peer.Status != nil {
|
|
||||||
peer.Status.LastSeen = time.Time{}
|
|
||||||
}
|
|
||||||
peer.LastLogin = &time.Time{}
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Slice(networkMap.Peers, func(i, j int) bool { return networkMap.Peers[i].ID < networkMap.Peers[j].ID })
|
|
||||||
sort.Slice(networkMap.OfflinePeers, func(i, j int) bool { return networkMap.OfflinePeers[i].ID < networkMap.OfflinePeers[j].ID })
|
|
||||||
sort.Slice(networkMap.Routes, func(i, j int) bool { return networkMap.Routes[i].ID < networkMap.Routes[j].ID })
|
|
||||||
|
|
||||||
sort.Slice(networkMap.FirewallRules, func(i, j int) bool {
|
|
||||||
r1, r2 := networkMap.FirewallRules[i], networkMap.FirewallRules[j]
|
|
||||||
if r1.PeerIP != r2.PeerIP {
|
|
||||||
return r1.PeerIP < r2.PeerIP
|
|
||||||
}
|
|
||||||
if r1.Protocol != r2.Protocol {
|
|
||||||
return r1.Protocol < r2.Protocol
|
|
||||||
}
|
|
||||||
if r1.Direction != r2.Direction {
|
|
||||||
return r1.Direction < r2.Direction
|
|
||||||
}
|
|
||||||
if r1.Action != r2.Action {
|
|
||||||
return r1.Action < r2.Action
|
|
||||||
}
|
|
||||||
return r1.Port < r2.Port
|
|
||||||
})
|
|
||||||
|
|
||||||
sort.Slice(networkMap.RoutesFirewallRules, func(i, j int) bool {
|
|
||||||
r1, r2 := networkMap.RoutesFirewallRules[i], networkMap.RoutesFirewallRules[j]
|
|
||||||
if r1.RouteID != r2.RouteID {
|
|
||||||
return r1.RouteID < r2.RouteID
|
|
||||||
}
|
|
||||||
if r1.Action != r2.Action {
|
|
||||||
return r1.Action < r2.Action
|
|
||||||
}
|
|
||||||
if r1.Destination != r2.Destination {
|
|
||||||
return r1.Destination < r2.Destination
|
|
||||||
}
|
|
||||||
if len(r1.SourceRanges) > 0 && len(r2.SourceRanges) > 0 {
|
|
||||||
if r1.SourceRanges[0] != r2.SourceRanges[0] {
|
|
||||||
return r1.SourceRanges[0] < r2.SourceRanges[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return r1.Port < r2.Port
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, ranges := range networkMap.RoutesFirewallRules {
|
|
||||||
sort.Slice(ranges.SourceRanges, func(i, j int) bool {
|
|
||||||
return ranges.SourceRanges[i] < ranges.SourceRanges[j]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type networkMapJSON struct {
|
|
||||||
Peers []*nbpeer.Peer `json:"Peers"`
|
|
||||||
Network *types.Network `json:"Network"`
|
|
||||||
Routes []*route.Route `json:"Routes"`
|
|
||||||
DNSConfig dns.Config `json:"DNSConfig"`
|
|
||||||
OfflinePeers []*nbpeer.Peer `json:"OfflinePeers"`
|
|
||||||
FirewallRules []*types.FirewallRule `json:"FirewallRules"`
|
|
||||||
RoutesFirewallRules []*types.RouteFirewallRule `json:"RoutesFirewallRules"`
|
|
||||||
ForwardingRules []*types.ForwardingRule `json:"ForwardingRules"`
|
|
||||||
AuthorizedUsers map[string][]string `json:"AuthorizedUsers,omitempty"`
|
|
||||||
EnableSSH bool `json:"EnableSSH"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func toNetworkMapJSON(nm *types.NetworkMap) *networkMapJSON {
|
|
||||||
result := &networkMapJSON{
|
|
||||||
Peers: nm.Peers,
|
|
||||||
Network: nm.Network,
|
|
||||||
Routes: nm.Routes,
|
|
||||||
DNSConfig: nm.DNSConfig,
|
|
||||||
OfflinePeers: nm.OfflinePeers,
|
|
||||||
FirewallRules: nm.FirewallRules,
|
|
||||||
RoutesFirewallRules: nm.RoutesFirewallRules,
|
|
||||||
ForwardingRules: nm.ForwardingRules,
|
|
||||||
EnableSSH: nm.EnableSSH,
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(nm.AuthorizedUsers) > 0 {
|
|
||||||
result.AuthorizedUsers = make(map[string][]string)
|
|
||||||
localUsers := make([]string, 0, len(nm.AuthorizedUsers))
|
|
||||||
for localUser := range nm.AuthorizedUsers {
|
|
||||||
localUsers = append(localUsers, localUser)
|
|
||||||
}
|
|
||||||
sort.Strings(localUsers)
|
|
||||||
|
|
||||||
for _, localUser := range localUsers {
|
|
||||||
userIDs := nm.AuthorizedUsers[localUser]
|
|
||||||
sortedUserIDs := make([]string, 0, len(userIDs))
|
|
||||||
for userID := range userIDs {
|
|
||||||
sortedUserIDs = append(sortedUserIDs, userID)
|
|
||||||
}
|
|
||||||
sort.Strings(sortedUserIDs)
|
|
||||||
result.AuthorizedUsers[localUser] = sortedUserIDs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func createTestAccountWithEntities() *types.Account {
|
|
||||||
peers := make(map[string]*nbpeer.Peer)
|
|
||||||
devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{}
|
|
||||||
|
|
||||||
for i := range numPeers {
|
|
||||||
peerID := fmt.Sprintf("peer-%d", i)
|
|
||||||
ip := net.IP{100, 64, 0, byte(i + 1)}
|
|
||||||
wtVersion := "0.25.0"
|
|
||||||
if i%2 == 0 {
|
|
||||||
wtVersion = "0.40.0"
|
|
||||||
}
|
|
||||||
|
|
||||||
p := &nbpeer.Peer{
|
|
||||||
ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1),
|
|
||||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
|
||||||
UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"},
|
|
||||||
}
|
|
||||||
|
|
||||||
if peerID == expiredPeerID {
|
|
||||||
p.LoginExpirationEnabled = true
|
|
||||||
pastTimestamp := time.Now().Add(-2 * time.Hour)
|
|
||||||
p.LastLogin = &pastTimestamp
|
|
||||||
}
|
|
||||||
|
|
||||||
peers[peerID] = p
|
|
||||||
allGroupPeers = append(allGroupPeers, peerID)
|
|
||||||
if i < numPeers/2 {
|
|
||||||
devGroupPeers = append(devGroupPeers, peerID)
|
|
||||||
} else {
|
|
||||||
opsGroupPeers = append(opsGroupPeers, peerID)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
groups := map[string]*types.Group{
|
|
||||||
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
|
|
||||||
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
|
|
||||||
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
|
|
||||||
sshUsersGroupID: {ID: sshUsersGroupID, Name: "SSH Users", Peers: []string{}},
|
|
||||||
}
|
|
||||||
|
|
||||||
policies := []*types.Policy{
|
|
||||||
{
|
|
||||||
ID: policyIDAll, Name: "Default-Allow", Enabled: true,
|
|
||||||
Rules: []*types.PolicyRule{{
|
|
||||||
ID: policyIDAll, Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
|
||||||
Protocol: types.PolicyRuleProtocolALL, Bidirectional: true,
|
|
||||||
Sources: []string{allGroupID}, Destinations: []string{allGroupID},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true,
|
|
||||||
Rules: []*types.PolicyRule{{
|
|
||||||
ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
|
||||||
Protocol: types.PolicyRuleProtocolTCP, Bidirectional: false,
|
|
||||||
PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}},
|
|
||||||
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true,
|
|
||||||
Rules: []*types.PolicyRule{{
|
|
||||||
ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop,
|
|
||||||
Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true,
|
|
||||||
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true,
|
|
||||||
SourcePostureChecks: []string{postureCheckID},
|
|
||||||
Rules: []*types.PolicyRule{{
|
|
||||||
ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
|
||||||
Protocol: types.PolicyRuleProtocolALL, Bidirectional: true,
|
|
||||||
Sources: []string{opsGroupID}, DestinationResource: types.Resource{ID: networkResourceID},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: policyIDSSH, Name: "SSH Access Policy", Enabled: true,
|
|
||||||
Rules: []*types.PolicyRule{{
|
|
||||||
ID: policyIDSSH, Name: "Allow SSH to Ops", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
|
||||||
Protocol: types.PolicyRuleProtocolNetbirdSSH, Bidirectional: false,
|
|
||||||
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
|
|
||||||
AuthorizedGroups: map[string][]string{sshUsersGroupID: {"root", "admin"}},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
routes := map[route.ID]*route.Route{
|
|
||||||
routeID: {
|
|
||||||
ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"),
|
|
||||||
Peer: peers["peer-75"].Key,
|
|
||||||
PeerID: "peer-75",
|
|
||||||
Description: "Route to internal resource", Enabled: true,
|
|
||||||
PeerGroups: []string{devGroupID, opsGroupID},
|
|
||||||
Groups: []string{devGroupID, opsGroupID},
|
|
||||||
AccessControlGroups: []string{devGroupID},
|
|
||||||
},
|
|
||||||
routeHA1ID: {
|
|
||||||
ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
|
|
||||||
Peer: peers["peer-80"].Key,
|
|
||||||
PeerID: "peer-80",
|
|
||||||
Description: "HA Route 1", Enabled: true, Metric: 1000,
|
|
||||||
PeerGroups: []string{allGroupID},
|
|
||||||
Groups: []string{allGroupID},
|
|
||||||
AccessControlGroups: []string{allGroupID},
|
|
||||||
},
|
|
||||||
routeHA2ID: {
|
|
||||||
ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
|
|
||||||
Peer: peers["peer-90"].Key,
|
|
||||||
PeerID: "peer-90",
|
|
||||||
Description: "HA Route 2", Enabled: true, Metric: 900,
|
|
||||||
PeerGroups: []string{devGroupID, opsGroupID},
|
|
||||||
Groups: []string{devGroupID, opsGroupID},
|
|
||||||
AccessControlGroups: []string{allGroupID},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
users := map[string]*types.User{
|
|
||||||
userAdminID: {Id: userAdminID, Role: types.UserRoleAdmin, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{allGroupID}},
|
|
||||||
userDevID: {Id: userDevID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, devGroupID}},
|
|
||||||
userOpsID: {Id: userOpsID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, opsGroupID}},
|
|
||||||
}
|
|
||||||
|
|
||||||
account := &types.Account{
|
|
||||||
Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes,
|
|
||||||
Users: users,
|
|
||||||
Network: &types.Network{
|
|
||||||
Identifier: "net-golden-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
|
|
||||||
},
|
|
||||||
DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{opsGroupID}},
|
|
||||||
NameServerGroups: map[string]*dns.NameServerGroup{
|
|
||||||
nameserverGroupID: {
|
|
||||||
ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID},
|
|
||||||
NameServers: []dns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: dns.UDPNameServerType, Port: 53}},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
PostureChecks: []*posture.Checks{
|
|
||||||
{ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{
|
|
||||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
|
|
||||||
}},
|
|
||||||
},
|
|
||||||
NetworkResources: []*resourceTypes.NetworkResource{
|
|
||||||
{ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"},
|
|
||||||
},
|
|
||||||
Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}},
|
|
||||||
NetworkRouters: []*routerTypes.NetworkRouter{
|
|
||||||
{ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID},
|
|
||||||
},
|
|
||||||
Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, p := range account.Policies {
|
|
||||||
p.AccountID = account.Id
|
|
||||||
}
|
|
||||||
for _, r := range account.Routes {
|
|
||||||
r.AccountID = account.Id
|
|
||||||
}
|
|
||||||
|
|
||||||
return account
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched(t *testing.T) {
|
|
||||||
account := createTestAccountWithEntities()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
validatedPeersMap := make(map[string]struct{})
|
|
||||||
for i := range numPeers {
|
|
||||||
peerID := fmt.Sprintf("peer-%d", i)
|
|
||||||
if peerID == offlinePeerID {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
validatedPeersMap[peerID] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
|
||||||
|
|
||||||
newRouterID := "peer-new-router-102"
|
|
||||||
newRouterIP := net.IP{100, 64, 1, 2}
|
|
||||||
newRouter := &nbpeer.Peer{
|
|
||||||
ID: newRouterID,
|
|
||||||
IP: newRouterIP,
|
|
||||||
Key: fmt.Sprintf("key-%s", newRouterID),
|
|
||||||
DNSLabel: "newrouter102",
|
|
||||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
|
||||||
UserID: "user-admin",
|
|
||||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
|
|
||||||
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Peers[newRouterID] = newRouter
|
|
||||||
|
|
||||||
if opsGroup, exists := account.Groups[opsGroupID]; exists {
|
|
||||||
opsGroup.Peers = append(opsGroup.Peers, newRouterID)
|
|
||||||
}
|
|
||||||
if allGroup, exists := account.Groups[allGroupID]; exists {
|
|
||||||
allGroup.Peers = append(allGroup.Peers, newRouterID)
|
|
||||||
}
|
|
||||||
|
|
||||||
newRoute := &route.Route{
|
|
||||||
ID: route.ID("route-new-router"),
|
|
||||||
Network: netip.MustParsePrefix("172.16.0.0/24"),
|
|
||||||
Peer: newRouter.Key,
|
|
||||||
PeerID: newRouterID,
|
|
||||||
Description: "Route from new router",
|
|
||||||
Enabled: true,
|
|
||||||
PeerGroups: []string{opsGroupID},
|
|
||||||
Groups: []string{devGroupID, opsGroupID},
|
|
||||||
AccessControlGroups: []string{devGroupID},
|
|
||||||
AccountID: account.Id,
|
|
||||||
}
|
|
||||||
account.Routes[newRoute.ID] = newRoute
|
|
||||||
|
|
||||||
validatedPeersMap[newRouterID] = struct{}{}
|
|
||||||
|
|
||||||
if account.Network != nil {
|
|
||||||
account.Network.Serial++
|
|
||||||
}
|
|
||||||
|
|
||||||
builder.EnqueuePeersForIncrementalAdd(account, newRouterID)
|
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
|
|
||||||
|
|
||||||
normalizeAndSortNetworkMap(networkMap)
|
|
||||||
|
|
||||||
jsonData, err := json.MarshalIndent(networkMap, "", " ")
|
|
||||||
require.NoError(t, err, "error marshaling network map to JSON")
|
|
||||||
|
|
||||||
goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")
|
|
||||||
|
|
||||||
t.Log("Update golden file with OnPeerAdded router...")
|
|
||||||
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = os.WriteFile(goldenFilePath, jsonData, 0644)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
expectedJSON, err := os.ReadFile(goldenFilePath)
|
|
||||||
require.NoError(t, err, "error reading golden file")
|
|
||||||
|
|
||||||
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file")
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
409
proxy/management_byop_integration_test.go
Normal file
409
proxy/management_byop_integration_test.go
Normal file
@@ -0,0 +1,409 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.opentelemetry.io/otel/metric/noop"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
grpcstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
"github.com/netbirdio/netbird/management/server/users"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type byopTestSetup struct {
|
||||||
|
store store.Store
|
||||||
|
proxyService *nbgrpc.ProxyServiceServer
|
||||||
|
grpcServer *grpc.Server
|
||||||
|
grpcAddr string
|
||||||
|
cleanup func()
|
||||||
|
|
||||||
|
accountA string
|
||||||
|
accountB string
|
||||||
|
accountAToken types.PlainProxyToken
|
||||||
|
accountBToken types.PlainProxyToken
|
||||||
|
accountACluster string
|
||||||
|
accountBCluster string
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup {
|
||||||
|
t.Helper()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountAID := "byop-account-a"
|
||||||
|
accountBID := "byop-account-b"
|
||||||
|
|
||||||
|
for _, acc := range []*types.Account{
|
||||||
|
{Id: accountAID, Domain: "a.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()},
|
||||||
|
{Id: accountBID, Domain: "b.test.com", DomainCategory: "private", IsDomainPrimaryAccount: true, CreatedAt: time.Now()},
|
||||||
|
} {
|
||||||
|
require.NoError(t, testStore.SaveAccount(ctx, acc))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
pubKey := base64.StdEncoding.EncodeToString(pub)
|
||||||
|
privKey := base64.StdEncoding.EncodeToString(priv)
|
||||||
|
|
||||||
|
clusterA := "byop-a.proxy.test"
|
||||||
|
clusterB := "byop-b.proxy.test"
|
||||||
|
|
||||||
|
services := []*service.Service{
|
||||||
|
{
|
||||||
|
ID: "svc-a1", AccountID: accountAID, Name: "App A1",
|
||||||
|
Domain: "app1." + clusterA, ProxyCluster: clusterA, Enabled: true,
|
||||||
|
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
|
||||||
|
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.1", Port: 8080, Protocol: "http", TargetId: "peer-a1", TargetType: "peer", Enabled: true}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "svc-a2", AccountID: accountAID, Name: "App A2",
|
||||||
|
Domain: "app2." + clusterA, ProxyCluster: clusterA, Enabled: true,
|
||||||
|
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
|
||||||
|
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.2", Port: 8080, Protocol: "http", TargetId: "peer-a2", TargetType: "peer", Enabled: true}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "svc-b1", AccountID: accountBID, Name: "App B1",
|
||||||
|
Domain: "app1." + clusterB, ProxyCluster: clusterB, Enabled: true,
|
||||||
|
SessionPrivateKey: privKey, SessionPublicKey: pubKey,
|
||||||
|
Targets: []*service.Target{{Path: strPtr("/"), Host: "10.0.0.3", Port: 8080, Protocol: "http", TargetId: "peer-b1", TargetType: "peer", Enabled: true}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, svc := range services {
|
||||||
|
require.NoError(t, testStore.CreateService(ctx, svc))
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenA, err := types.CreateNewProxyAccessToken("byop-token-a", 0, &accountAID, "admin-a")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenA.ProxyAccessToken))
|
||||||
|
|
||||||
|
tokenB, err := types.CreateNewProxyAccessToken("byop-token-b", 0, &accountBID, "admin-b")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, testStore.SaveProxyAccessToken(ctx, &tokenB.ProxyAccessToken))
|
||||||
|
|
||||||
|
cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore)
|
||||||
|
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore)
|
||||||
|
|
||||||
|
meter := noop.NewMeterProvider().Meter("test")
|
||||||
|
realProxyManager, err := proxymanager.NewManager(testStore, meter)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
oidcConfig := nbgrpc.ProxyOIDCConfig{
|
||||||
|
Issuer: "https://fake-issuer.example.com",
|
||||||
|
ClientID: "test-client",
|
||||||
|
HMACKey: []byte("test-hmac-key"),
|
||||||
|
}
|
||||||
|
|
||||||
|
usersManager := users.NewManager(testStore)
|
||||||
|
|
||||||
|
proxyService := nbgrpc.NewProxyServiceServer(
|
||||||
|
&testAccessLogManager{},
|
||||||
|
tokenStore,
|
||||||
|
pkceStore,
|
||||||
|
oidcConfig,
|
||||||
|
nil,
|
||||||
|
usersManager,
|
||||||
|
realProxyManager,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore}
|
||||||
|
proxyService.SetServiceManager(svcMgr)
|
||||||
|
|
||||||
|
proxyController := &testProxyController{}
|
||||||
|
proxyService.SetProxyController(proxyController)
|
||||||
|
|
||||||
|
_, streamInterceptor, authClose := nbgrpc.NewProxyAuthInterceptors(testStore)
|
||||||
|
|
||||||
|
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
grpcServer := grpc.NewServer(grpc.StreamInterceptor(streamInterceptor))
|
||||||
|
proto.RegisterProxyServiceServer(grpcServer, proxyService)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := grpcServer.Serve(lis); err != nil {
|
||||||
|
t.Logf("gRPC server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &byopTestSetup{
|
||||||
|
store: testStore,
|
||||||
|
proxyService: proxyService,
|
||||||
|
grpcServer: grpcServer,
|
||||||
|
grpcAddr: lis.Addr().String(),
|
||||||
|
cleanup: func() {
|
||||||
|
grpcServer.GracefulStop()
|
||||||
|
authClose()
|
||||||
|
storeCleanup()
|
||||||
|
},
|
||||||
|
accountA: accountAID,
|
||||||
|
accountB: accountBID,
|
||||||
|
accountAToken: tokenA.PlainToken,
|
||||||
|
accountBToken: tokenB.PlainToken,
|
||||||
|
accountACluster: clusterA,
|
||||||
|
accountBCluster: clusterB,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func byopContext(ctx context.Context, token types.PlainProxyToken) context.Context {
|
||||||
|
md := metadata.Pairs("authorization", "Bearer "+string(token))
|
||||||
|
return metadata.NewOutgoingContext(ctx, md)
|
||||||
|
}
|
||||||
|
|
||||||
|
func receiveBYOPMappings(t *testing.T, stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
|
||||||
|
t.Helper()
|
||||||
|
var mappings []*proto.ProxyMapping
|
||||||
|
for {
|
||||||
|
msg, err := stream.Recv()
|
||||||
|
require.NoError(t, err)
|
||||||
|
mappings = append(mappings, msg.GetMapping()...)
|
||||||
|
if msg.GetInitialSyncComplete() {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mappings
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_BYOPProxy_ReceivesOnlyAccountServices(t *testing.T) {
|
||||||
|
setup := setupBYOPIntegrationTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewProxyServiceClient(conn)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||||
|
ProxyId: "byop-proxy-a",
|
||||||
|
Version: "test-v1",
|
||||||
|
Address: setup.accountACluster,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mappings := receiveBYOPMappings(t, stream)
|
||||||
|
|
||||||
|
assert.Len(t, mappings, 2, "BYOP proxy should receive only account A's 2 services")
|
||||||
|
for _, m := range mappings {
|
||||||
|
assert.Equal(t, setup.accountA, m.GetAccountId(), "all mappings should belong to account A")
|
||||||
|
t.Logf("received mapping: id=%s domain=%s account=%s", m.GetId(), m.GetDomain(), m.GetAccountId())
|
||||||
|
}
|
||||||
|
|
||||||
|
ids := map[string]bool{}
|
||||||
|
for _, m := range mappings {
|
||||||
|
ids[m.GetId()] = true
|
||||||
|
}
|
||||||
|
assert.True(t, ids["svc-a1"], "should contain svc-a1")
|
||||||
|
assert.True(t, ids["svc-a2"], "should contain svc-a2")
|
||||||
|
assert.False(t, ids["svc-b1"], "should NOT contain account B's svc-b1")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_BYOPProxy_AccountBReceivesOnlyItsServices(t *testing.T) {
|
||||||
|
setup := setupBYOPIntegrationTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewProxyServiceClient(conn)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||||
|
ProxyId: "byop-proxy-b",
|
||||||
|
Version: "test-v1",
|
||||||
|
Address: setup.accountBCluster,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mappings := receiveBYOPMappings(t, stream)
|
||||||
|
|
||||||
|
assert.Len(t, mappings, 1, "BYOP proxy B should receive only 1 service")
|
||||||
|
assert.Equal(t, "svc-b1", mappings[0].GetId())
|
||||||
|
assert.Equal(t, setup.accountB, mappings[0].GetAccountId())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_BYOPProxy_MultiplePerAccount(t *testing.T) {
|
||||||
|
setup := setupBYOPIntegrationTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewProxyServiceClient(conn)
|
||||||
|
|
||||||
|
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||||
|
defer cancel1()
|
||||||
|
|
||||||
|
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||||
|
ProxyId: "byop-proxy-a-first",
|
||||||
|
Version: "test-v1",
|
||||||
|
Address: setup.accountACluster,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mappings1 := receiveBYOPMappings(t, stream1)
|
||||||
|
assert.Len(t, mappings1, 2, "first BYOP proxy should receive account A's 2 services")
|
||||||
|
|
||||||
|
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||||
|
ProxyId: "byop-proxy-a-second",
|
||||||
|
Version: "test-v1",
|
||||||
|
Address: setup.accountACluster,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mappings2 := receiveBYOPMappings(t, stream2)
|
||||||
|
assert.Len(t, mappings2, 2, "second BYOP proxy from same account should also receive the 2 services")
|
||||||
|
for _, m := range mappings2 {
|
||||||
|
assert.Equal(t, setup.accountA, m.GetAccountId())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_BYOPProxy_ClusterAddressConflict(t *testing.T) {
|
||||||
|
setup := setupBYOPIntegrationTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewProxyServiceClient(conn)
|
||||||
|
|
||||||
|
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||||
|
defer cancel1()
|
||||||
|
|
||||||
|
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||||
|
ProxyId: "byop-proxy-a-cluster",
|
||||||
|
Version: "test-v1",
|
||||||
|
Address: setup.accountACluster,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_ = receiveBYOPMappings(t, stream1)
|
||||||
|
|
||||||
|
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountBToken), 5*time.Second)
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||||
|
ProxyId: "byop-proxy-b-conflict",
|
||||||
|
Version: "test-v1",
|
||||||
|
Address: setup.accountACluster,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = stream2.Recv()
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
st, ok := grpcstatus.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, codes.AlreadyExists, st.Code(), "cluster address conflict should return AlreadyExists")
|
||||||
|
t.Logf("expected rejection: %s", st.Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_BYOPProxy_SameProxyReconnects(t *testing.T) {
|
||||||
|
setup := setupBYOPIntegrationTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewProxyServiceClient(conn)
|
||||||
|
|
||||||
|
proxyID := "byop-proxy-reconnect"
|
||||||
|
|
||||||
|
ctx1, cancel1 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||||
|
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
|
||||||
|
ProxyId: proxyID,
|
||||||
|
Version: "test-v1",
|
||||||
|
Address: setup.accountACluster,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
firstMappings := receiveBYOPMappings(t, stream1)
|
||||||
|
cancel1()
|
||||||
|
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
ctx2, cancel2 := context.WithTimeout(byopContext(context.Background(), setup.accountAToken), 5*time.Second)
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
|
||||||
|
ProxyId: proxyID,
|
||||||
|
Version: "test-v1",
|
||||||
|
Address: setup.accountACluster,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
secondMappings := receiveBYOPMappings(t, stream2)
|
||||||
|
|
||||||
|
assert.Equal(t, len(firstMappings), len(secondMappings), "reconnect should receive same mappings")
|
||||||
|
|
||||||
|
firstIDs := map[string]bool{}
|
||||||
|
for _, m := range firstMappings {
|
||||||
|
firstIDs[m.GetId()] = true
|
||||||
|
}
|
||||||
|
for _, m := range secondMappings {
|
||||||
|
assert.True(t, firstIDs[m.GetId()], "mapping %s should be present on reconnect", m.GetId())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegration_BYOPProxy_UnauthenticatedRejected(t *testing.T) {
|
||||||
|
setup := setupBYOPIntegrationTest(t)
|
||||||
|
defer setup.cleanup()
|
||||||
|
|
||||||
|
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewProxyServiceClient(conn)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||||
|
ProxyId: "no-auth-proxy",
|
||||||
|
Version: "test-v1",
|
||||||
|
Address: "some.cluster.io",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = stream.Recv()
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
st, ok := grpcstatus.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, codes.Unauthenticated, st.Code())
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -140,6 +141,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
|||||||
nil,
|
nil,
|
||||||
usersManager,
|
usersManager,
|
||||||
proxyManager,
|
proxyManager,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
// Use store-backed service manager
|
// Use store-backed service manager
|
||||||
@@ -201,7 +203,7 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
|
|||||||
// testProxyManager is a mock implementation of proxy.Manager for testing.
|
// testProxyManager is a mock implementation of proxy.Manager for testing.
|
||||||
type testProxyManager struct{}
|
type testProxyManager struct{}
|
||||||
|
|
||||||
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *nbproxy.Capabilities) error {
|
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string, _ *string, _ *nbproxy.Capabilities) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -217,6 +219,10 @@ func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]strin
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *testProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Cluster, error) {
|
func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Cluster, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -237,6 +243,22 @@ func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *testProxyManager) GetAccountProxy(_ context.Context, accountID string) (*nbproxy.Proxy, error) {
|
||||||
|
return nil, fmt.Errorf("proxy not found for account %s", accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *testProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// testProxyController is a mock implementation of rpservice.ProxyController for testing.
|
// testProxyController is a mock implementation of rpservice.ProxyController for testing.
|
||||||
type testProxyController struct{}
|
type testProxyController struct{}
|
||||||
|
|
||||||
@@ -290,6 +312,10 @@ func (m *storeBackedServiceManager) DeleteService(ctx context.Context, accountID
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *storeBackedServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -336,6 +362,10 @@ func (m *storeBackedServiceManager) StopServiceFromPeer(_ context.Context, _, _,
|
|||||||
|
|
||||||
func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {}
|
func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {}
|
||||||
|
|
||||||
|
func (m *storeBackedServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||||
|
return m.store.GetServiceByDomain(ctx, domain)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3323,10 +3323,64 @@ components:
|
|||||||
example: false
|
example: false
|
||||||
required:
|
required:
|
||||||
- enabled
|
- enabled
|
||||||
|
ProxyTokenRequest:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
description: Human-readable token name
|
||||||
|
example: "my-proxy-token"
|
||||||
|
expires_in:
|
||||||
|
type: integer
|
||||||
|
minimum: 0
|
||||||
|
description: Token expiration in seconds (0 = never expires)
|
||||||
|
example: 0
|
||||||
|
required:
|
||||||
|
- name
|
||||||
|
ProxyToken:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: string
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
expires_at:
|
||||||
|
type: string
|
||||||
|
format: date-time
|
||||||
|
created_at:
|
||||||
|
type: string
|
||||||
|
format: date-time
|
||||||
|
last_used:
|
||||||
|
type: string
|
||||||
|
format: date-time
|
||||||
|
revoked:
|
||||||
|
type: boolean
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
- name
|
||||||
|
- created_at
|
||||||
|
- revoked
|
||||||
|
ProxyTokenCreated:
|
||||||
|
type: object
|
||||||
|
description: Returned on creation — plain_token is shown only once
|
||||||
|
allOf:
|
||||||
|
- $ref: '#/components/schemas/ProxyToken'
|
||||||
|
- type: object
|
||||||
|
properties:
|
||||||
|
plain_token:
|
||||||
|
type: string
|
||||||
|
description: The plain text token (shown only once)
|
||||||
|
example: "nbx_abc123..."
|
||||||
|
required:
|
||||||
|
- plain_token
|
||||||
ProxyCluster:
|
ProxyCluster:
|
||||||
type: object
|
type: object
|
||||||
description: A proxy cluster represents a group of proxy nodes serving the same address
|
description: A proxy cluster represents a group of proxy nodes serving the same address
|
||||||
properties:
|
properties:
|
||||||
|
id:
|
||||||
|
type: string
|
||||||
|
description: Unique identifier of a proxy in this cluster
|
||||||
|
example: "chlfq4q5r8kc73b0qjpg"
|
||||||
address:
|
address:
|
||||||
type: string
|
type: string
|
||||||
description: Cluster address used for CNAME targets
|
description: Cluster address used for CNAME targets
|
||||||
@@ -3335,9 +3389,15 @@ components:
|
|||||||
type: integer
|
type: integer
|
||||||
description: Number of proxy nodes connected in this cluster
|
description: Number of proxy nodes connected in this cluster
|
||||||
example: 3
|
example: 3
|
||||||
|
self_hosted:
|
||||||
|
type: boolean
|
||||||
|
description: Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner
|
||||||
|
example: false
|
||||||
required:
|
required:
|
||||||
|
- id
|
||||||
- address
|
- address
|
||||||
- connected_proxies
|
- connected_proxies
|
||||||
|
- self_hosted
|
||||||
ReverseProxyDomainType:
|
ReverseProxyDomainType:
|
||||||
type: string
|
type: string
|
||||||
description: Type of Reverse Proxy Domain
|
description: Type of Reverse Proxy Domain
|
||||||
@@ -11317,6 +11377,111 @@ paths:
|
|||||||
"$ref": "#/components/responses/forbidden"
|
"$ref": "#/components/responses/forbidden"
|
||||||
'500':
|
'500':
|
||||||
"$ref": "#/components/responses/internal_error"
|
"$ref": "#/components/responses/internal_error"
|
||||||
|
/api/reverse-proxies/clusters/{clusterAddress}:
|
||||||
|
delete:
|
||||||
|
summary: Delete a self-hosted proxy cluster
|
||||||
|
description: Removes all self-hosted (BYOP) proxy registrations for the given cluster address owned by the account.
|
||||||
|
tags: [ Services ]
|
||||||
|
security:
|
||||||
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
|
parameters:
|
||||||
|
- in: path
|
||||||
|
name: clusterAddress
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
description: The address of the proxy cluster
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: Proxy cluster deleted successfully
|
||||||
|
content: { }
|
||||||
|
'400':
|
||||||
|
"$ref": "#/components/responses/bad_request"
|
||||||
|
'401':
|
||||||
|
"$ref": "#/components/responses/requires_authentication"
|
||||||
|
'403':
|
||||||
|
"$ref": "#/components/responses/forbidden"
|
||||||
|
'404':
|
||||||
|
"$ref": "#/components/responses/not_found"
|
||||||
|
'500':
|
||||||
|
"$ref": "#/components/responses/internal_error"
|
||||||
|
/api/reverse-proxies/proxy-tokens:
|
||||||
|
get:
|
||||||
|
summary: List Proxy Tokens
|
||||||
|
description: Returns all proxy access tokens for the account
|
||||||
|
tags: [ Self-Hosted Proxies ]
|
||||||
|
security:
|
||||||
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: A JSON Array of proxy tokens
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/ProxyToken'
|
||||||
|
'401':
|
||||||
|
"$ref": "#/components/responses/requires_authentication"
|
||||||
|
'403':
|
||||||
|
"$ref": "#/components/responses/forbidden"
|
||||||
|
'500':
|
||||||
|
"$ref": "#/components/responses/internal_error"
|
||||||
|
post:
|
||||||
|
summary: Create a Proxy Token
|
||||||
|
description: Generate an account-scoped proxy access token for self-hosted proxy registration
|
||||||
|
tags: [ Self-Hosted Proxies ]
|
||||||
|
security:
|
||||||
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
|
requestBody:
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/ProxyTokenRequest'
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: Proxy token created (plain token shown once)
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/ProxyTokenCreated'
|
||||||
|
'400':
|
||||||
|
"$ref": "#/components/responses/bad_request"
|
||||||
|
'401':
|
||||||
|
"$ref": "#/components/responses/requires_authentication"
|
||||||
|
'403':
|
||||||
|
"$ref": "#/components/responses/forbidden"
|
||||||
|
'500':
|
||||||
|
"$ref": "#/components/responses/internal_error"
|
||||||
|
/api/reverse-proxies/proxy-tokens/{tokenId}:
|
||||||
|
delete:
|
||||||
|
summary: Revoke a Proxy Token
|
||||||
|
description: Revoke an account-scoped proxy access token
|
||||||
|
tags: [ Self-Hosted Proxies ]
|
||||||
|
security:
|
||||||
|
- BearerAuth: [ ]
|
||||||
|
- TokenAuth: [ ]
|
||||||
|
parameters:
|
||||||
|
- in: path
|
||||||
|
name: tokenId
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
description: The unique identifier of the proxy token
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: Token revoked
|
||||||
|
'401':
|
||||||
|
"$ref": "#/components/responses/requires_authentication"
|
||||||
|
'403':
|
||||||
|
"$ref": "#/components/responses/forbidden"
|
||||||
|
'404':
|
||||||
|
"$ref": "#/components/responses/not_found"
|
||||||
|
'500':
|
||||||
|
"$ref": "#/components/responses/internal_error"
|
||||||
/api/reverse-proxies/services:
|
/api/reverse-proxies/services:
|
||||||
get:
|
get:
|
||||||
summary: List all Services
|
summary: List all Services
|
||||||
|
|||||||
@@ -3761,11 +3761,49 @@ type ProxyAccessLogsResponse struct {
|
|||||||
|
|
||||||
// ProxyCluster A proxy cluster represents a group of proxy nodes serving the same address
|
// ProxyCluster A proxy cluster represents a group of proxy nodes serving the same address
|
||||||
type ProxyCluster struct {
|
type ProxyCluster struct {
|
||||||
|
// Id Unique identifier of a proxy in this cluster
|
||||||
|
Id string `json:"id"`
|
||||||
|
|
||||||
// Address Cluster address used for CNAME targets
|
// Address Cluster address used for CNAME targets
|
||||||
Address string `json:"address"`
|
Address string `json:"address"`
|
||||||
|
|
||||||
// ConnectedProxies Number of proxy nodes connected in this cluster
|
// ConnectedProxies Number of proxy nodes connected in this cluster
|
||||||
ConnectedProxies int `json:"connected_proxies"`
|
ConnectedProxies int `json:"connected_proxies"`
|
||||||
|
|
||||||
|
// SelfHosted Whether this cluster is a self-hosted (BYOP) proxy managed by the account owner
|
||||||
|
SelfHosted bool `json:"self_hosted"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyToken defines model for ProxyToken.
|
||||||
|
type ProxyToken struct {
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||||
|
Id string `json:"id"`
|
||||||
|
LastUsed *time.Time `json:"last_used,omitempty"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Revoked bool `json:"revoked"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyTokenCreated defines model for ProxyTokenCreated.
|
||||||
|
type ProxyTokenCreated struct {
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||||
|
Id string `json:"id"`
|
||||||
|
LastUsed *time.Time `json:"last_used,omitempty"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
|
||||||
|
// PlainToken The plain text token (shown only once)
|
||||||
|
PlainToken string `json:"plain_token"`
|
||||||
|
Revoked bool `json:"revoked"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyTokenRequest defines model for ProxyTokenRequest.
|
||||||
|
type ProxyTokenRequest struct {
|
||||||
|
// ExpiresIn Token expiration in seconds (0 = never expires)
|
||||||
|
ExpiresIn *int `json:"expires_in,omitempty"`
|
||||||
|
|
||||||
|
// Name Human-readable token name
|
||||||
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resource defines model for Resource.
|
// Resource defines model for Resource.
|
||||||
@@ -5127,6 +5165,9 @@ type PutApiPostureChecksPostureCheckIdJSONRequestBody = PostureCheckUpdate
|
|||||||
// PostApiReverseProxiesDomainsJSONRequestBody defines body for PostApiReverseProxiesDomains for application/json ContentType.
|
// PostApiReverseProxiesDomainsJSONRequestBody defines body for PostApiReverseProxiesDomains for application/json ContentType.
|
||||||
type PostApiReverseProxiesDomainsJSONRequestBody = ReverseProxyDomainRequest
|
type PostApiReverseProxiesDomainsJSONRequestBody = ReverseProxyDomainRequest
|
||||||
|
|
||||||
|
// PostApiReverseProxiesProxyTokensJSONRequestBody defines body for PostApiReverseProxiesProxyTokens for application/json ContentType.
|
||||||
|
type PostApiReverseProxiesProxyTokensJSONRequestBody = ProxyTokenRequest
|
||||||
|
|
||||||
// PostApiReverseProxiesServicesJSONRequestBody defines body for PostApiReverseProxiesServices for application/json ContentType.
|
// PostApiReverseProxiesServicesJSONRequestBody defines body for PostApiReverseProxiesServices for application/json ContentType.
|
||||||
type PostApiReverseProxiesServicesJSONRequestBody = ServiceRequest
|
type PostApiReverseProxiesServicesJSONRequestBody = ServiceRequest
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user