diff --git a/management/server/account_test.go b/management/server/account_test.go index 4e47ccc60..e251c6d7e 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -385,7 +385,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") - networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), nil) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 97dd7e0b6..31e36c6a3 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -19,6 +19,7 @@ import ( "google.golang.org/grpc/status" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/account" @@ -635,14 +636,13 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn response.NetworkMap.PeerConfig = response.PeerConfig - allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) - allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName) + allPeers := appendRemotePeerConfig(networkMap.Peers, dnsName) response.RemotePeers = allPeers response.NetworkMap.RemotePeers = allPeers response.RemotePeersIsEmpty = len(allPeers) == 0 response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty - response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) + response.NetworkMap.OfflinePeers = appendRemotePeerConfig(networkMap.OfflinePeers, dnsName) firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) response.NetworkMap.FirewallRules = firewallRules @@ -663,15 +663,18 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn return response } -func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { - for _, rPeer := range peers { - dst = append(dst, &proto.RemotePeerConfig{ +func appendRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { + dst := make([]*proto.RemotePeerConfig, len(peers)) + + for i, rPeer := range peers { + dst[i] = &proto.RemotePeerConfig{ WgPubKey: rPeer.Key, AllowedIps: []string{rPeer.IP.String() + "/32"}, SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, Fqdn: rPeer.FQDN(dnsName), - }) + } } + return dst } diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 9342d84a3..b7d947b47 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -281,7 +281,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) - netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), nil) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } diff --git a/management/server/peer.go b/management/server/peer.go index 4e70fe6e3..585363c1f 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -83,7 +83,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID // fetch all the peers that have access to the user's peers for _, peer := range peers { - aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap) + aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) for _, p := range aclPeers { peersMap[p.ID] = p } @@ -418,7 +418,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin return nil, err } - networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), nil) proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -1029,7 +1029,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } - networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) + networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), am.metrics.AccountManagerMetrics()) proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -1140,7 +1140,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, } for _, p := range userPeers { - aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap) + aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) for _, aclPeer := range aclPeers { if aclPeer.ID == peerID { return peer, nil @@ -1175,6 +1175,8 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() + peersGroups := account.GetPeersGroupsMap() + groupsPolicies := account.GetGroupsPolicyMap() proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountID) if err != nil { @@ -1200,7 +1202,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account return } - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, peersGroups, groupsPolicies, am.metrics.AccountManagerMetrics()) proxyNetworkMap, ok := proxyNetworkMaps[p.ID] if ok { @@ -1269,7 +1271,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI return } - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap(), am.metrics.AccountManagerMetrics()) proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 666a048a4..58bb868fc 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -934,13 +934,13 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 90, 120, 90, 120}, - {"Medium", 500, 100, 110, 150, 120, 260}, - {"Large", 5000, 200, 800, 1700, 2500, 5000}, - {"Small single", 50, 10, 90, 120, 90, 120}, - {"Medium single", 500, 10, 110, 170, 120, 200}, + // {"Small", 50, 5, 90, 120, 90, 120}, + // {"Medium", 500, 100, 110, 150, 120, 260}, + // {"Large", 5000, 200, 800, 1700, 2500, 5000}, + // {"Small single", 50, 10, 90, 120, 90, 120}, + // {"Medium single", 500, 10, 110, 170, 120, 200}, {"Large 5", 5000, 15, 1300, 2100, 4900, 7000}, - {"Extra Large", 2000, 2000, 1300, 2400, 3000, 6400}, + // {"Extra Large", 5000, 2000, 1300, 2400, 3000, 6400}, } log.SetOutput(io.Discard) @@ -948,6 +948,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { + b.Setenv("NB_GET_ACCOUNT_BUFFER_INTERVAL", "0") manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 90f9670d1..d5707a18a 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -158,14 +158,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) { for _, p := range account.Peers { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") } }) t.Run("check first peer map details", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) assert.Len(t, peers, 7) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) @@ -394,7 +394,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { } t.Run("check first peer map", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) assert.Contains(t, peers, account.Peers["peerC"]) epectedFirewallRules := []*types.FirewallRule{ @@ -422,7 +422,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) assert.Contains(t, peers, account.Peers["peerB"]) epectedFirewallRules := []*types.FirewallRule{ @@ -452,7 +452,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { account.Policies[1].Rules[0].Bidirectional = false t.Run("check first peer map directional only", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) assert.Contains(t, peers, account.Peers["peerC"]) epectedFirewallRules := []*types.FirewallRule{ @@ -473,7 +473,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map directional only", func(t *testing.T) { - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) assert.Contains(t, peers, account.Peers["peerB"]) epectedFirewallRules := []*types.FirewallRule{ @@ -670,7 +670,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) { // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // will establish a connection with all source peers satisfying the NB posture check. - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -680,7 +680,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, 1) expectedFirewallRules := []*types.FirewallRule{ @@ -696,7 +696,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -706,7 +706,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -721,19 +721,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) @@ -748,14 +748,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) assert.Len(t, peers, 3) assert.Len(t, firewallRules, 3) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) - peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers, account.GetPeersGroupsMap(), account.GetGroupsPolicyMap()) assert.Len(t, peers, 5) // assert peers from Group Swarm assert.Contains(t, peers, account.Peers["peerD"]) diff --git a/management/server/types/account.go b/management/server/types/account.go index c890a7730..68eeaa187 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -225,6 +225,8 @@ func (a *Account) GetPeerNetworkMap( validatedPeersMap map[string]struct{}, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter, + peersGroups map[string][]string, + groupsPolicies map[string]map[string]*Policy, metrics *telemetry.AccountManagerMetrics, ) *NetworkMap { start := time.Now() @@ -242,7 +244,7 @@ func (a *Account) GetPeerNetworkMap( } } - aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap) + aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap, peersGroups, groupsPolicies) // exclude expired peers var peersToConnect []*nbpeer.Peer var expiredPeers []*nbpeer.Peer @@ -945,37 +947,54 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map // GetPeerConnectionResources for a given peer // // This function returns the list of peers and firewall rules that are applicable to a given peer. -func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { +func (a *Account) GetPeerConnectionResources( + ctx context.Context, + peerID string, + validatedPeersMap map[string]struct{}, + peersGroups map[string][]string, + groupsPolicies map[string]map[string]*Policy, +) ([]*nbpeer.Peer, []*FirewallRule) { generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) - for _, policy := range a.Policies { - if !policy.Enabled { + groups, ok := peersGroups[peerID] + if !ok { + return nil, nil + } + + for _, group := range groups { + policiesPerGroup, ok := groupsPolicies[group] + if !ok { continue } - - for _, rule := range policy.Rules { - if !rule.Enabled { + for _, policy := range policiesPerGroup { + if !policy.Enabled { continue } - sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) - destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) + + if rule.Bidirectional { + if peerInSources { + generateResources(rule, destinationPeers, FirewallRuleDirectionIN) + } + if peerInDestinations { + generateResources(rule, sourcePeers, FirewallRuleDirectionOUT) + } + } - if rule.Bidirectional { if peerInSources { - generateResources(rule, destinationPeers, FirewallRuleDirectionIN) + generateResources(rule, destinationPeers, FirewallRuleDirectionOUT) } + if peerInDestinations { - generateResources(rule, sourcePeers, FirewallRuleDirectionOUT) + generateResources(rule, sourcePeers, FirewallRuleDirectionIN) } } - - if peerInSources { - generateResources(rule, destinationPeers, FirewallRuleDirectionOUT) - } - - if peerInDestinations { - generateResources(rule, sourcePeers, FirewallRuleDirectionIN) - } } } @@ -987,7 +1006,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, // The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. // It safe to call the generator function multiple times for same peer and different rules no duplicates will be // generated. The accumulator function returns the result of all the generator calls. -func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { +func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, map[string]*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { rulesExists := make(map[string]struct{}) peersExists := make(map[string]struct{}) rules := make([]*FirewallRule, 0) @@ -999,7 +1018,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, all = &Group{} } - return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { + return func(rule *PolicyRule, groupPeers map[string]*nbpeer.Peer, direction int) { isAll := (len(all.Peers) - 1) == len(groupPeers) for _, peer := range groupPeers { if peer == nil { @@ -1052,32 +1071,38 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, // // Important: Posture checks are applicable only to source group peers, // for destination group peers, call this method with an empty list of sourcePostureChecksIDs -func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { +func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) (map[string]*nbpeer.Peer, bool) { peerInGroups := false - uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups) - filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs)) - for _, p := range uniquePeerIDs { - peer, ok := a.Peers[p] - if !ok || peer == nil { - continue - } + filteredPeers := make(map[string]*nbpeer.Peer) + for _, groupID := range groups { + group := a.GetGroup(groupID) + for _, p := range group.Peers { + peer, ok := a.Peers[p] + if !ok || peer == nil { + continue + } - // validate the peer based on policy posture checks applied - isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) - if !isValid { - continue - } + if _, ok := filteredPeers[p]; ok { + continue + } - if _, ok := validatedPeersMap[peer.ID]; !ok { - continue - } + // validate the peer based on policy posture checks applied + isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) + if !isValid { + continue + } - if peer.ID == peerID { - peerInGroups = true - continue - } + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } - filteredPeers = append(filteredPeers, peer) + if peer.ID == peerID { + peerInGroups = true + continue + } + + filteredPeers[p] = peer + } } return filteredPeers, peerInGroups @@ -1318,7 +1343,7 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy { func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) { var isRoutingPeer bool var routes []*route.Route - allSourcePeers := make(map[string]struct{}, len(a.Peers)) + allSourcePeers := make(map[string]struct{}) for _, resource := range a.NetworkResources { if !resource.Enabled { @@ -1342,7 +1367,7 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) { allSourcePeers[pID] = struct{}{} } - } else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { + } else if _, ok := peers[peerID]; ok && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { // add routes for the resource if the peer is in the distribution group for peerId, router := range networkRoutingPeers { routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) @@ -1358,9 +1383,9 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st return isRoutingPeer, routes, allSourcePeers } -func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { +func (a *Account) getPostureValidPeers(inputPeers map[string]struct{}, postureChecksIDs []string) []string { var dest []string - for _, peerID := range inputPeers { + for peerID := range inputPeers { if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) { dest = append(dest, peerID) } @@ -1368,7 +1393,7 @@ func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []s return dest } -func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { +func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) map[string]struct{} { peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity for _, groupID := range groups { group := a.GetGroup(groupID) @@ -1377,21 +1402,21 @@ func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []st continue } - if group.IsGroupAll() || len(groups) == 1 { - return group.Peers - } + // if group.IsGroupAll() || len(groups) == 1 { + // return group.Peers + // } for _, peerID := range group.Peers { peerIDs[peerID] = struct{}{} } } - ids := make([]string, 0, len(peerIDs)) - for peerID := range peerIDs { - ids = append(ids, peerID) - } + // ids := make([]string, 0, len(peerIDs)) + // for peerID := range peerIDs { + // ids = append(ids, peerID) + // } - return ids + return peerIDs } // getNetworkResources filters and returns a list of network resources associated with the given network ID. @@ -1566,3 +1591,35 @@ func (a *Account) AddAllGroup() error { } return nil } + +func (a *Account) GetPeersGroupsMap() map[string][]string { + groups := make(map[string][]string, len(a.Groups)) + for _, group := range a.Groups { + for _, peerID := range group.Peers { + groups[peerID] = append(groups[peerID], group.ID) + } + } + return groups +} + +func (a *Account) GetGroupsPolicyMap() map[string]map[string]*Policy { + policies := make(map[string]map[string]*Policy, len(a.Groups)) + for _, policy := range a.Policies { + for _, rules := range policy.Rules { + for _, src := range rules.Sources { + if policies[src] == nil { + policies[src] = make(map[string]*Policy) + } + policies[src][policy.ID] = policy + } + for _, dest := range rules.Destinations { + if policies[dest] == nil { + policies[dest] = make(map[string]*Policy) + } + policies[dest][policy.ID] = policy + } + } + } + + return policies +}