Compare commits

...

1 Commits

Author SHA1 Message Date
Pascal Fischer
4d2c774378 refactor networm map generation 2025-03-13 14:29:59 +01:00
7 changed files with 155 additions and 92 deletions

View File

@@ -385,7 +385,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
} }
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), nil)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
} }

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
@@ -635,14 +636,13 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn
response.NetworkMap.PeerConfig = response.PeerConfig response.NetworkMap.PeerConfig = response.PeerConfig
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) allPeers := appendRemotePeerConfig(networkMap.Peers, dnsName)
allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName)
response.RemotePeers = allPeers response.RemotePeers = allPeers
response.NetworkMap.RemotePeers = allPeers response.NetworkMap.RemotePeers = allPeers
response.RemotePeersIsEmpty = len(allPeers) == 0 response.RemotePeersIsEmpty = len(allPeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) response.NetworkMap.OfflinePeers = appendRemotePeerConfig(networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
response.NetworkMap.FirewallRules = firewallRules response.NetworkMap.FirewallRules = firewallRules
@@ -663,15 +663,18 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn
return response return response
} }
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { func appendRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
for _, rPeer := range peers { dst := make([]*proto.RemotePeerConfig, len(peers))
dst = append(dst, &proto.RemotePeerConfig{
for i, rPeer := range peers {
dst[i] = &proto.RemotePeerConfig{
WgPubKey: rPeer.Key, WgPubKey: rPeer.Key,
AllowedIps: []string{rPeer.IP.String() + "/32"}, AllowedIps: []string{rPeer.IP.String() + "/32"},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName), Fqdn: rPeer.FQDN(dnsName),
}) }
} }
return dst return dst
} }

View File

@@ -281,7 +281,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), nil)
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
} }

View File

@@ -83,7 +83,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
// fetch all the peers that have access to the user's peers // fetch all the peers that have access to the user's peers
for _, peer := range peers { for _, peer := range peers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap) aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
for _, p := range aclPeers { for _, p := range aclPeers {
peersMap[p.ID] = p peersMap[p.ID] = p
} }
@@ -418,7 +418,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
return nil, err return nil, err
} }
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), nil)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok { if ok {
@@ -1029,7 +1029,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return nil, nil, nil, err return nil, nil, nil, err
} }
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), am.metrics.AccountManagerMetrics())
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok { if ok {
@@ -1140,7 +1140,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
} }
for _, p := range userPeers { for _, p := range userPeers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap) aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
for _, aclPeer := range aclPeers { for _, aclPeer := range aclPeers {
if aclPeer.ID == peerID { if aclPeer.ID == peerID {
return peer, nil return peer, nil
@@ -1175,6 +1175,8 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap() resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap() routers := account.GetResourceRoutersMap()
peersGroups := account.GetPeersGroupsMap()
groupsPolicies := account.GetGroupsPolicyMap()
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountID) proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountID)
if err != nil { if err != nil {
@@ -1200,7 +1202,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
return return
} }
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, peersGroups, groupsPolicies, am.metrics.AccountManagerMetrics())
proxyNetworkMap, ok := proxyNetworkMaps[p.ID] proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok { if ok {
@@ -1269,7 +1271,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
return return
} }
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), am.metrics.AccountManagerMetrics())
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok { if ok {

View File

@@ -934,13 +934,13 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
minMsPerOpCICD float64 minMsPerOpCICD float64
maxMsPerOpCICD float64 maxMsPerOpCICD float64
}{ }{
{"Small", 50, 5, 90, 120, 90, 120}, // {"Small", 50, 5, 90, 120, 90, 120},
{"Medium", 500, 100, 110, 150, 120, 260}, // {"Medium", 500, 100, 110, 150, 120, 260},
{"Large", 5000, 200, 800, 1700, 2500, 5000}, // {"Large", 5000, 200, 800, 1700, 2500, 5000},
{"Small single", 50, 10, 90, 120, 90, 120}, // {"Small single", 50, 10, 90, 120, 90, 120},
{"Medium single", 500, 10, 110, 170, 120, 200}, // {"Medium single", 500, 10, 110, 170, 120, 200},
{"Large 5", 5000, 15, 1300, 2100, 4900, 7000}, {"Large 5", 5000, 15, 1300, 2100, 4900, 7000},
{"Extra Large", 2000, 2000, 1300, 2400, 3000, 6400}, // {"Extra Large", 5000, 2000, 1300, 2400, 3000, 6400},
} }
log.SetOutput(io.Discard) log.SetOutput(io.Discard)
@@ -948,6 +948,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
for _, bc := range benchCases { for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) { b.Run(bc.name, func(b *testing.B) {
b.Setenv("NB_GET_ACCOUNT_BUFFER_INTERVAL", "0")
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil { if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err) b.Fatalf("Failed to setup test account manager: %v", err)

View File

@@ -158,14 +158,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
t.Run("check that all peers get map", func(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) {
for _, p := range account.Peers { for _, p := range account.Peers {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present")
} }
}) })
t.Run("check first peer map details", func(t *testing.T) { t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 7) assert.Len(t, peers, 7)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
@@ -394,7 +394,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
} }
t.Run("check first peer map", func(t *testing.T) { t.Run("check first peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*types.FirewallRule{ epectedFirewallRules := []*types.FirewallRule{
@@ -422,7 +422,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}) })
t.Run("check second peer map", func(t *testing.T) { t.Run("check second peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Contains(t, peers, account.Peers["peerB"]) assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*types.FirewallRule{ epectedFirewallRules := []*types.FirewallRule{
@@ -452,7 +452,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
account.Policies[1].Rules[0].Bidirectional = false account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) { t.Run("check first peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*types.FirewallRule{ epectedFirewallRules := []*types.FirewallRule{
@@ -473,7 +473,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}) })
t.Run("check second peer map directional only", func(t *testing.T) { t.Run("check second peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Contains(t, peers, account.Peers["peerB"]) assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*types.FirewallRule{ epectedFirewallRules := []*types.FirewallRule{
@@ -670,7 +670,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
t.Run("verify peer's network map with default group peer list", func(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check. // will establish a connection with all source peers satisfying the NB posture check.
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -680,7 +680,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's // peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections // We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 1) assert.Len(t, firewallRules, 1)
expectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
@@ -696,7 +696,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -706,7 +706,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -721,19 +721,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group // no connection should be established to any peer of destination group
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 0) assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0) assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group // no connection should be established to any peer of destination group
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 0) assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0) assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's // peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections // We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@@ -748,14 +748,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 3) assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3) assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 5) assert.Len(t, peers, 5)
// assert peers from Group Swarm // assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])

View File

@@ -225,6 +225,8 @@ func (a *Account) GetPeerNetworkMap(
validatedPeersMap map[string]struct{}, validatedPeersMap map[string]struct{},
resourcePolicies map[string][]*Policy, resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter, routers map[string]map[string]*routerTypes.NetworkRouter,
peersGroups map[string][]string,
groupsPolicies map[string]map[string]*Policy,
metrics *telemetry.AccountManagerMetrics, metrics *telemetry.AccountManagerMetrics,
) *NetworkMap { ) *NetworkMap {
start := time.Now() start := time.Now()
@@ -242,7 +244,7 @@ func (a *Account) GetPeerNetworkMap(
} }
} }
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap) aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap, peersGroups, groupsPolicies)
// exclude expired peers // exclude expired peers
var peersToConnect []*nbpeer.Peer var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer var expiredPeers []*nbpeer.Peer
@@ -945,37 +947,54 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
// GetPeerConnectionResources for a given peer // GetPeerConnectionResources for a given peer
// //
// This function returns the list of peers and firewall rules that are applicable to a given peer. // This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { func (a *Account) GetPeerConnectionResources(
ctx context.Context,
peerID string,
validatedPeersMap map[string]struct{},
peersGroups map[string][]string,
groupsPolicies map[string]map[string]*Policy,
) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
for _, policy := range a.Policies { groups, ok := peersGroups[peerID]
if !policy.Enabled { if !ok {
return nil, nil
}
for _, group := range groups {
policiesPerGroup, ok := groupsPolicies[group]
if !ok {
continue continue
} }
for _, policy := range policiesPerGroup {
for _, rule := range policy.Rules { if !policy.Enabled {
if !rule.Enabled {
continue continue
} }
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) for _, rule := range policy.Rules {
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) if !rule.Enabled {
continue
}
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
if rule.Bidirectional {
if peerInSources {
generateResources(rule, destinationPeers, FirewallRuleDirectionIN)
}
if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionOUT)
}
}
if rule.Bidirectional {
if peerInSources { if peerInSources {
generateResources(rule, destinationPeers, FirewallRuleDirectionIN) generateResources(rule, destinationPeers, FirewallRuleDirectionOUT)
} }
if peerInDestinations { if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionOUT) generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
} }
} }
if peerInSources {
generateResources(rule, destinationPeers, FirewallRuleDirectionOUT)
}
if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
}
} }
} }
@@ -987,7 +1006,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string,
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. // The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be // It safe to call the generator function multiple times for same peer and different rules no duplicates will be
// generated. The accumulator function returns the result of all the generator calls. // generated. The accumulator function returns the result of all the generator calls.
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, map[string]*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
rulesExists := make(map[string]struct{}) rulesExists := make(map[string]struct{})
peersExists := make(map[string]struct{}) peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0) rules := make([]*FirewallRule, 0)
@@ -999,7 +1018,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
all = &Group{} all = &Group{}
} }
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { return func(rule *PolicyRule, groupPeers map[string]*nbpeer.Peer, direction int) {
isAll := (len(all.Peers) - 1) == len(groupPeers) isAll := (len(all.Peers) - 1) == len(groupPeers)
for _, peer := range groupPeers { for _, peer := range groupPeers {
if peer == nil { if peer == nil {
@@ -1052,32 +1071,38 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
// //
// Important: Posture checks are applicable only to source group peers, // Important: Posture checks are applicable only to source group peers,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs // for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) (map[string]*nbpeer.Peer, bool) {
peerInGroups := false peerInGroups := false
uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups) filteredPeers := make(map[string]*nbpeer.Peer)
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs)) for _, groupID := range groups {
for _, p := range uniquePeerIDs { group := a.GetGroup(groupID)
peer, ok := a.Peers[p] for _, p := range group.Peers {
if !ok || peer == nil { peer, ok := a.Peers[p]
continue if !ok || peer == nil {
} continue
}
// validate the peer based on policy posture checks applied if _, ok := filteredPeers[p]; ok {
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) continue
if !isValid { }
continue
}
if _, ok := validatedPeersMap[peer.ID]; !ok { // validate the peer based on policy posture checks applied
continue isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
} if !isValid {
continue
}
if peer.ID == peerID { if _, ok := validatedPeersMap[peer.ID]; !ok {
peerInGroups = true continue
continue }
}
filteredPeers = append(filteredPeers, peer) if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeers[p] = peer
}
} }
return filteredPeers, peerInGroups return filteredPeers, peerInGroups
@@ -1318,7 +1343,7 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) { func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
var isRoutingPeer bool var isRoutingPeer bool
var routes []*route.Route var routes []*route.Route
allSourcePeers := make(map[string]struct{}, len(a.Peers)) allSourcePeers := make(map[string]struct{})
for _, resource := range a.NetworkResources { for _, resource := range a.NetworkResources {
if !resource.Enabled { if !resource.Enabled {
@@ -1342,7 +1367,7 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) { for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
allSourcePeers[pID] = struct{}{} allSourcePeers[pID] = struct{}{}
} }
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { } else if _, ok := peers[peerID]; ok && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
// add routes for the resource if the peer is in the distribution group // add routes for the resource if the peer is in the distribution group
for peerId, router := range networkRoutingPeers { for peerId, router := range networkRoutingPeers {
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
@@ -1358,9 +1383,9 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
return isRoutingPeer, routes, allSourcePeers return isRoutingPeer, routes, allSourcePeers
} }
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { func (a *Account) getPostureValidPeers(inputPeers map[string]struct{}, postureChecksIDs []string) []string {
var dest []string var dest []string
for _, peerID := range inputPeers { for peerID := range inputPeers {
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) { if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
dest = append(dest, peerID) dest = append(dest, peerID)
} }
@@ -1368,7 +1393,7 @@ func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []s
return dest return dest
} }
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) map[string]struct{} {
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
for _, groupID := range groups { for _, groupID := range groups {
group := a.GetGroup(groupID) group := a.GetGroup(groupID)
@@ -1377,21 +1402,21 @@ func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []st
continue continue
} }
if group.IsGroupAll() || len(groups) == 1 { // if group.IsGroupAll() || len(groups) == 1 {
return group.Peers // return group.Peers
} // }
for _, peerID := range group.Peers { for _, peerID := range group.Peers {
peerIDs[peerID] = struct{}{} peerIDs[peerID] = struct{}{}
} }
} }
ids := make([]string, 0, len(peerIDs)) // ids := make([]string, 0, len(peerIDs))
for peerID := range peerIDs { // for peerID := range peerIDs {
ids = append(ids, peerID) // ids = append(ids, peerID)
} // }
return ids return peerIDs
} }
// getNetworkResources filters and returns a list of network resources associated with the given network ID. // getNetworkResources filters and returns a list of network resources associated with the given network ID.
@@ -1566,3 +1591,35 @@ func (a *Account) AddAllGroup() error {
} }
return nil return nil
} }
func (a *Account) GetPeersGroupsMap() map[string][]string {
groups := make(map[string][]string, len(a.Groups))
for _, group := range a.Groups {
for _, peerID := range group.Peers {
groups[peerID] = append(groups[peerID], group.ID)
}
}
return groups
}
func (a *Account) GetGroupsPolicyMap() map[string]map[string]*Policy {
policies := make(map[string]map[string]*Policy, len(a.Groups))
for _, policy := range a.Policies {
for _, rules := range policy.Rules {
for _, src := range rules.Sources {
if policies[src] == nil {
policies[src] = make(map[string]*Policy)
}
policies[src][policy.ID] = policy
}
for _, dest := range rules.Destinations {
if policies[dest] == nil {
policies[dest] = make(map[string]*Policy)
}
policies[dest][policy.ID] = policy
}
}
}
return policies
}