refactor networm map generation

This commit is contained in:
Pascal Fischer
2025-03-13 14:29:59 +01:00
parent ab2e3fec72
commit 4d2c774378
7 changed files with 155 additions and 92 deletions

View File

@@ -385,7 +385,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), nil)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/account"
@@ -635,14 +636,13 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn
response.NetworkMap.PeerConfig = response.PeerConfig
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName)
allPeers := appendRemotePeerConfig(networkMap.Peers, dnsName)
response.RemotePeers = allPeers
response.NetworkMap.RemotePeers = allPeers
response.RemotePeersIsEmpty = len(allPeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
response.NetworkMap.FirewallRules = firewallRules
@@ -663,15 +663,18 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn
return response
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
dst = append(dst, &proto.RemotePeerConfig{
func appendRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
dst := make([]*proto.RemotePeerConfig, len(peers))
for i, rPeer := range peers {
dst[i] = &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{rPeer.IP.String() + "/32"},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
})
}
}
return dst
}

View File

@@ -281,7 +281,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain()
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), nil)
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}

View File

@@ -83,7 +83,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
// fetch all the peers that have access to the user's peers
for _, peer := range peers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
for _, p := range aclPeers {
peersMap[p.ID] = p
}
@@ -418,7 +418,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
return nil, err
}
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), nil)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
@@ -1029,7 +1029,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return nil, nil, nil, err
}
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics())
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), am.metrics.AccountManagerMetrics())
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
@@ -1140,7 +1140,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
}
for _, p := range userPeers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap)
aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
for _, aclPeer := range aclPeers {
if aclPeer.ID == peerID {
return peer, nil
@@ -1175,6 +1175,8 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
peersGroups := account.GetPeersGroupsMap()
groupsPolicies := account.GetGroupsPolicyMap()
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountID)
if err != nil {
@@ -1200,7 +1202,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
return
}
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, peersGroups, groupsPolicies, am.metrics.AccountManagerMetrics())
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok {
@@ -1269,7 +1271,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
return
}
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), am.metrics.AccountManagerMetrics())
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {

View File

@@ -934,13 +934,13 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 90, 120, 90, 120},
{"Medium", 500, 100, 110, 150, 120, 260},
{"Large", 5000, 200, 800, 1700, 2500, 5000},
{"Small single", 50, 10, 90, 120, 90, 120},
{"Medium single", 500, 10, 110, 170, 120, 200},
// {"Small", 50, 5, 90, 120, 90, 120},
// {"Medium", 500, 100, 110, 150, 120, 260},
// {"Large", 5000, 200, 800, 1700, 2500, 5000},
// {"Small single", 50, 10, 90, 120, 90, 120},
// {"Medium single", 500, 10, 110, 170, 120, 200},
{"Large 5", 5000, 15, 1300, 2100, 4900, 7000},
{"Extra Large", 2000, 2000, 1300, 2400, 3000, 6400},
// {"Extra Large", 5000, 2000, 1300, 2400, 3000, 6400},
}
log.SetOutput(io.Discard)
@@ -948,6 +948,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
b.Setenv("NB_GET_ACCOUNT_BUFFER_INTERVAL", "0")
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)

View File

@@ -158,14 +158,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
t.Run("check that all peers get map", func(t *testing.T) {
for _, p := range account.Peers {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present")
}
})
t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 7)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
@@ -394,7 +394,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}
t.Run("check first peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*types.FirewallRule{
@@ -422,7 +422,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
})
t.Run("check second peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*types.FirewallRule{
@@ -452,7 +452,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*types.FirewallRule{
@@ -473,7 +473,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
})
t.Run("check second peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*types.FirewallRule{
@@ -670,7 +670,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check.
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -680,7 +680,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 1)
expectedFirewallRules := []*types.FirewallRule{
@@ -696,7 +696,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -706,7 +706,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@@ -721,19 +721,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@@ -748,14 +748,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"])
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap())
assert.Len(t, peers, 5)
// assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"])

View File

@@ -225,6 +225,8 @@ func (a *Account) GetPeerNetworkMap(
validatedPeersMap map[string]struct{},
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
peersGroups map[string][]string,
groupsPolicies map[string]map[string]*Policy,
metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
start := time.Now()
@@ -242,7 +244,7 @@ func (a *Account) GetPeerNetworkMap(
}
}
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap)
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap, peersGroups, groupsPolicies)
// exclude expired peers
var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer
@@ -945,37 +947,54 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
// GetPeerConnectionResources for a given peer
//
// This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
func (a *Account) GetPeerConnectionResources(
ctx context.Context,
peerID string,
validatedPeersMap map[string]struct{},
peersGroups map[string][]string,
groupsPolicies map[string]map[string]*Policy,
) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
for _, policy := range a.Policies {
if !policy.Enabled {
groups, ok := peersGroups[peerID]
if !ok {
return nil, nil
}
for _, group := range groups {
policiesPerGroup, ok := groupsPolicies[group]
if !ok {
continue
}
for _, rule := range policy.Rules {
if !rule.Enabled {
for _, policy := range policiesPerGroup {
if !policy.Enabled {
continue
}
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
if rule.Bidirectional {
if peerInSources {
generateResources(rule, destinationPeers, FirewallRuleDirectionIN)
}
if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionOUT)
}
}
if rule.Bidirectional {
if peerInSources {
generateResources(rule, destinationPeers, FirewallRuleDirectionIN)
generateResources(rule, destinationPeers, FirewallRuleDirectionOUT)
}
if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionOUT)
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
}
}
if peerInSources {
generateResources(rule, destinationPeers, FirewallRuleDirectionOUT)
}
if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
}
}
}
@@ -987,7 +1006,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string,
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
// generated. The accumulator function returns the result of all the generator calls.
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, map[string]*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
rulesExists := make(map[string]struct{})
peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0)
@@ -999,7 +1018,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
all = &Group{}
}
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
return func(rule *PolicyRule, groupPeers map[string]*nbpeer.Peer, direction int) {
isAll := (len(all.Peers) - 1) == len(groupPeers)
for _, peer := range groupPeers {
if peer == nil {
@@ -1052,32 +1071,38 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
//
// Important: Posture checks are applicable only to source group peers,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) (map[string]*nbpeer.Peer, bool) {
peerInGroups := false
uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups)
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
for _, p := range uniquePeerIDs {
peer, ok := a.Peers[p]
if !ok || peer == nil {
continue
}
filteredPeers := make(map[string]*nbpeer.Peer)
for _, groupID := range groups {
group := a.GetGroup(groupID)
for _, p := range group.Peers {
peer, ok := a.Peers[p]
if !ok || peer == nil {
continue
}
// validate the peer based on policy posture checks applied
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}
if _, ok := filteredPeers[p]; ok {
continue
}
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
// validate the peer based on policy posture checks applied
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
filteredPeers = append(filteredPeers, peer)
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeers[p] = peer
}
}
return filteredPeers, peerInGroups
@@ -1318,7 +1343,7 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
var isRoutingPeer bool
var routes []*route.Route
allSourcePeers := make(map[string]struct{}, len(a.Peers))
allSourcePeers := make(map[string]struct{})
for _, resource := range a.NetworkResources {
if !resource.Enabled {
@@ -1342,7 +1367,7 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
allSourcePeers[pID] = struct{}{}
}
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
} else if _, ok := peers[peerID]; ok && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
// add routes for the resource if the peer is in the distribution group
for peerId, router := range networkRoutingPeers {
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
@@ -1358,9 +1383,9 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
return isRoutingPeer, routes, allSourcePeers
}
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
func (a *Account) getPostureValidPeers(inputPeers map[string]struct{}, postureChecksIDs []string) []string {
var dest []string
for _, peerID := range inputPeers {
for peerID := range inputPeers {
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
dest = append(dest, peerID)
}
@@ -1368,7 +1393,7 @@ func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []s
return dest
}
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) map[string]struct{} {
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
for _, groupID := range groups {
group := a.GetGroup(groupID)
@@ -1377,21 +1402,21 @@ func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []st
continue
}
if group.IsGroupAll() || len(groups) == 1 {
return group.Peers
}
// if group.IsGroupAll() || len(groups) == 1 {
// return group.Peers
// }
for _, peerID := range group.Peers {
peerIDs[peerID] = struct{}{}
}
}
ids := make([]string, 0, len(peerIDs))
for peerID := range peerIDs {
ids = append(ids, peerID)
}
// ids := make([]string, 0, len(peerIDs))
// for peerID := range peerIDs {
// ids = append(ids, peerID)
// }
return ids
return peerIDs
}
// getNetworkResources filters and returns a list of network resources associated with the given network ID.
@@ -1566,3 +1591,35 @@ func (a *Account) AddAllGroup() error {
}
return nil
}
func (a *Account) GetPeersGroupsMap() map[string][]string {
groups := make(map[string][]string, len(a.Groups))
for _, group := range a.Groups {
for _, peerID := range group.Peers {
groups[peerID] = append(groups[peerID], group.ID)
}
}
return groups
}
func (a *Account) GetGroupsPolicyMap() map[string]map[string]*Policy {
policies := make(map[string]map[string]*Policy, len(a.Groups))
for _, policy := range a.Policies {
for _, rules := range policy.Rules {
for _, src := range rules.Sources {
if policies[src] == nil {
policies[src] = make(map[string]*Policy)
}
policies[src][policy.ID] = policy
}
for _, dest := range rules.Destinations {
if policies[dest] == nil {
policies[dest] = make(map[string]*Policy)
}
policies[dest][policy.ID] = policy
}
}
}
return policies
}