Compare commits

...

6 Commits

Author SHA1 Message Date
Maycon Santos
d6c5f5ead8 fix signature 2024-08-07 16:49:18 +02:00
Maycon Santos
afbc0e65d7 refactor duplicated code 2024-08-07 16:46:21 +02:00
Maycon Santos
31dd04e835 fix map init with peers 2024-08-07 16:39:31 +02:00
Maycon Santos
b6a8b1dbcd add docs 2024-08-07 15:30:04 +02:00
Maycon Santos
a29862182a update peers handlers 2024-08-07 15:28:05 +02:00
Maycon Santos
ec4469f43d Use policy expanded peers map from src/dest groups
Pre expand the peers from policy rules source and destination groups
to avoid extra allocation when calculating network map
2024-08-07 15:22:01 +02:00
6 changed files with 118 additions and 68 deletions

View File

@@ -413,6 +413,7 @@ func (a *Account) GetPeerNetworkMap(
peersCustomZone nbdns.CustomZone, peersCustomZone nbdns.CustomZone,
validatedPeersMap map[string]struct{}, validatedPeersMap map[string]struct{},
metrics *telemetry.AccountManagerMetrics, metrics *telemetry.AccountManagerMetrics,
expandedPolicies PolicyRuleExpandedPeers,
) *NetworkMap { ) *NetworkMap {
start := time.Now() start := time.Now()
@@ -429,7 +430,7 @@ func (a *Account) GetPeerNetworkMap(
} }
} }
aclPeers, firewallRules := a.getPeerConnectionResources(ctx, peerID, validatedPeersMap) aclPeers, firewallRules := a.getPeerConnectionResources(ctx, peerID, validatedPeersMap, expandedPolicies)
// exclude expired peers // exclude expired peers
var peersToConnect []*nbpeer.Peer var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer var expiredPeers []*nbpeer.Peer

View File

@@ -412,7 +412,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
} }
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil) policyExpandedPeers := account.GetPolicyExpandedPeers()
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil, policyExpandedPeers)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
} }

View File

@@ -64,19 +64,28 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
groupsInfo := toGroupsInfo(account.Groups, peer.ID) groupsInfo := toGroupsInfo(account.Groups, peer.ID)
validPeers, err := h.accountManager.GetValidatedPeers(account) accessiblePeers, valid, err := h.getAccessibleAndValidStatus(ctx, account, peerID, dnsDomain, peer)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w) util.WriteError(ctx, fmt.Errorf("internal error"), w)
return return
} }
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
}
func (h *PeersHandler) getAccessibleAndValidStatus(ctx context.Context, account *server.Account, peerID string, dnsDomain string, peer *nbpeer.Peer) ([]api.AccessiblePeer, bool, error) {
validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
return nil, false, err
}
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain()) customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil) policyExpandedPeers := account.GetPolicyExpandedPeers()
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil, policyExpandedPeers)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain) accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID] _, valid := validPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid)) return accessiblePeers, valid, nil
} }
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
@@ -110,19 +119,13 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
validPeers, err := h.accountManager.GetValidatedPeers(account) accessiblePeers, valid, err := h.getAccessibleAndValidStatus(ctx, account, peerID, dnsDomain, peer)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w) util.WriteError(ctx, fmt.Errorf("internal error"), w)
return return
} }
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID]
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid)) util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
} }

View File

@@ -87,8 +87,9 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
} }
// fetch all the peers that have access to the user's peers // fetch all the peers that have access to the user's peers
policyExpandedPeers := account.GetPolicyExpandedPeers()
for _, peer := range peers { for _, peer := range peers {
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap) aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap, policyExpandedPeers)
for _, p := range aclPeers { for _, p := range aclPeers {
peersMap[p.ID] = p peersMap[p.ID] = p
} }
@@ -324,7 +325,8 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
return nil, err return nil, err
} }
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil), nil policyExpandedPeers := account.GetPolicyExpandedPeers()
return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil, policyExpandedPeers), nil
} }
// GetPeerNetwork returns the Network for a given peer // GetPeerNetwork returns the Network for a given peer
@@ -538,7 +540,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
postureChecks := am.getPeerPostureChecks(account, peer) postureChecks := am.getPeerPostureChecks(account, peer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) policyExpandedPeers := account.GetPolicyExpandedPeers()
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics(), policyExpandedPeers)
return newPeer, networkMap, postureChecks, nil return newPeer, networkMap, postureChecks, nil
} }
@@ -595,7 +598,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
postureChecks = am.getPeerPostureChecks(account, peer) postureChecks = am.getPeerPostureChecks(account, peer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil policyExpandedPeers := account.GetPolicyExpandedPeers()
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics(), policyExpandedPeers), postureChecks, nil
} }
// LoginPeer logs in or registers a peer. // LoginPeer logs in or registers a peer.
@@ -743,7 +747,8 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
postureChecks = am.getPeerPostureChecks(account, peer) postureChecks = am.getPeerPostureChecks(account, peer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil policyExpandedPeers := account.GetPolicyExpandedPeers()
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics(), policyExpandedPeers), postureChecks, nil
} }
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error { func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error {
@@ -896,8 +901,9 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
return nil, err return nil, err
} }
policyExpandedPeers := account.GetPolicyExpandedPeers()
for _, p := range userPeers { for _, p := range userPeers {
aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap) aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap, policyExpandedPeers)
for _, aclPeer := range aclPeers { for _, aclPeer := range aclPeers {
if aclPeer.ID == peerID { if aclPeer.ID == peerID {
return peer, nil return peer, nil
@@ -939,7 +945,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
dnsCache := &DNSConfigCache{} dnsCache := &DNSConfigCache{}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
expandedPolicies := account.GetPolicyExpandedPeers()
for _, peer := range peers { for _, peer := range peers {
if !am.peersUpdateManager.HasChannel(peer.ID) { if !am.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
@@ -953,7 +959,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
defer func() { <-semaphore }() defer func() { <-semaphore }()
postureChecks := am.getPeerPostureChecks(account, p) postureChecks := am.getPeerPostureChecks(account, p)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics(), expandedPolicies)
update := toSyncResponse(ctx, nil, p, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) update := toSyncResponse(ctx, nil, p, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update}) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
}(peer) }(peer)

View File

@@ -212,20 +212,20 @@ type FirewallRule struct {
// getPeerConnectionResources for a given peer // getPeerConnectionResources for a given peer
// //
// This function returns the list of peers and firewall rules that are applicable to a given peer. // This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}, expandedPolicies PolicyRuleExpandedPeers) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
for _, policy := range a.Policies { for _, policy := range a.Policies {
if !policy.Enabled { if !policy.Enabled {
continue continue
} }
for _, rule := range policy.Rules { for n, rule := range policy.Rules {
if !rule.Enabled { if !rule.Enabled {
continue continue
} }
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, expandedPolicies[policy.ID][n].sourcePeers, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, expandedPolicies[policy.ID][n].destinationPeers, peerID, nil, validatedPeersMap)
if rule.Bidirectional { if rule.Bidirectional {
if peerInSources { if peerInSources {
@@ -490,38 +490,26 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
// //
// Important: Posture checks are applicable only to source group peers, // Important: Posture checks are applicable only to source group peers,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs // for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { func (a *Account) getAllPeersFromGroups(ctx context.Context, peerMap peerMap, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
peerInGroups := false peerInGroups := false
filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) filteredPeers := make([]*nbpeer.Peer, 0, len(peerMap))
for _, g := range groups { for _, peer := range peerMap {
group, ok := a.Groups[g]
if !ok { if _, ok := validatedPeersMap[peer.ID]; !ok {
continue continue
} }
for _, p := range group.Peers { isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
peer, ok := a.Peers[p] if !isValid {
if !ok || peer == nil { continue
continue
}
// validate the peer based on policy posture checks applied
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeers = append(filteredPeers, peer)
} }
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeers = append(filteredPeers, peer)
} }
return filteredPeers, peerInGroups return filteredPeers, peerInGroups
} }
@@ -560,3 +548,45 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
} }
return nil return nil
} }
type expandedRuleGroups struct {
sourcePeers peerMap
destinationPeers peerMap
}
type peerMap map[string]*nbpeer.Peer
// PolicyRuleExpandedPeers is a map with the peers of each policy rule source and destination groups
type PolicyRuleExpandedPeers map[string]map[int]expandedRuleGroups
// GetPolicyExpandedPeers returns a map with the peers of each policy rule source and destination groups
func (a *Account) GetPolicyExpandedPeers() PolicyRuleExpandedPeers {
policyMap := make(PolicyRuleExpandedPeers)
for _, policy := range a.Policies {
if !policy.Enabled {
continue
}
ruleMap := make(map[int]expandedRuleGroups)
policyMap[policy.ID] = ruleMap
for ruleID, rule := range policy.Rules {
policyMap[policy.ID][ruleID] = expandedRuleGroups{
sourcePeers: make(peerMap),
destinationPeers: make(peerMap),
}
a.processGroups(rule.Sources, policyMap[policy.ID][ruleID].sourcePeers)
a.processGroups(rule.Destinations, policyMap[policy.ID][ruleID].destinationPeers)
}
}
return policyMap
}
func (a *Account) processGroups(groupIDs []string, peerMap peerMap) {
for _, gid := range groupIDs {
for _, pid := range a.Groups[gid].Peers {
p, ok := a.Peers[pid]
if ok {
peerMap[pid] = p
}
}
}
}

View File

@@ -143,15 +143,17 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
} }
t.Run("check that all peers get map", func(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) {
policyExpandedPeers := account.GetPolicyExpandedPeers()
for _, p := range account.Peers { for _, p := range account.Peers {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers) peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers, policyExpandedPeers)
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present")
} }
}) })
t.Run("check first peer map details", func(t *testing.T) { t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers) policyExpandedPeers := account.GetPolicyExpandedPeers()
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers, policyExpandedPeers)
assert.Len(t, peers, 7) assert.Len(t, peers, 7)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
@@ -387,7 +389,8 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
} }
t.Run("check first peer map", func(t *testing.T) { t.Run("check first peer map", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) policyExpandedPeers := account.GetPolicyExpandedPeers()
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers, policyExpandedPeers)
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*FirewallRule{ epectedFirewallRules := []*FirewallRule{
@@ -415,7 +418,8 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}) })
t.Run("check second peer map", func(t *testing.T) { t.Run("check second peer map", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) policyExpandedPeers := account.GetPolicyExpandedPeers()
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers, policyExpandedPeers)
assert.Contains(t, peers, account.Peers["peerB"]) assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*FirewallRule{ epectedFirewallRules := []*FirewallRule{
@@ -445,7 +449,8 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
account.Policies[1].Rules[0].Bidirectional = false account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) { t.Run("check first peer map directional only", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) policyExpandedPeers := account.GetPolicyExpandedPeers()
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers, policyExpandedPeers)
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*FirewallRule{ epectedFirewallRules := []*FirewallRule{
@@ -466,7 +471,8 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}) })
t.Run("check second peer map directional only", func(t *testing.T) { t.Run("check second peer map directional only", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) policyExpandedPeers := account.GetPolicyExpandedPeers()
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers, policyExpandedPeers)
assert.Contains(t, peers, account.Peers["peerB"]) assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*FirewallRule{ epectedFirewallRules := []*FirewallRule{
@@ -661,9 +667,10 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
approvedPeers[p] = struct{}{} approvedPeers[p] = struct{}{}
} }
t.Run("verify peer's network map with default group peer list", func(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
policyExpandedPeers := account.GetPolicyExpandedPeers()
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check. // will establish a connection with all source peers satisfying the NB posture check.
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -673,7 +680,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's // peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections // We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 1) assert.Len(t, firewallRules, 1)
expectedFirewallRules := []*FirewallRule{ expectedFirewallRules := []*FirewallRule{
@@ -689,7 +696,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers) peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -699,7 +706,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers) peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -711,22 +718,23 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
t.Run("verify peer's network map with modified group peer list", func(t *testing.T) { t.Run("verify peer's network map with modified group peer list", func(t *testing.T) {
// Removing peerB as the part of destination group Swarm // Removing peerB as the part of destination group Swarm
account.Groups["GroupSwarm"].Peers = []string{"peerA", "peerD", "peerE", "peerG", "peerH"} account.Groups["GroupSwarm"].Peers = []string{"peerA", "peerD", "peerE", "peerG", "peerH"}
policyExpandedPeers := account.GetPolicyExpandedPeers()
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group // no connection should be established to any peer of destination group
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, 0) assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0) assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group // no connection should be established to any peer of destination group
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers) peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, 0) assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0) assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's // peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections // We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@@ -738,17 +746,18 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// Removing peerF as the part of source group All // Removing peerF as the part of source group All
account.Groups["GroupAll"].Peers = []string{"peerB", "peerA", "peerD", "peerC", "peerG", "peerH"} account.Groups["GroupAll"].Peers = []string{"peerB", "peerA", "peerD", "peerC", "peerG", "peerH"}
policyExpandedPeers = account.GetPolicyExpandedPeers()
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers) peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, 3) assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3) assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers) peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers, policyExpandedPeers)
assert.Len(t, peers, 5) assert.Len(t, peers, 5)
// assert peers from Group Swarm // assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])