diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index ad25494c7..48c95d81b 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -193,7 +193,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin if c.experimentalNetworkMap(accountID) { remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics) } else { - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics) + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, account.GetResourceMap(), c.accountManagerMetrics) } c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) @@ -307,7 +307,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe if c.experimentalNetworkMap(accountId) { remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics) } else { - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics) + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, account.GetResourceMap(), c.accountManagerMetrics) } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] @@ -446,7 +446,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr if c.experimentalNetworkMap(accountID) { networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics) } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics) + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetResourceMap(), c.accountManagerMetrics) } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] @@ -764,7 +764,7 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N if c.experimentalNetworkMap(peer.AccountID) { networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil) } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetResourceMap(), nil) } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] diff --git a/management/server/account_test.go b/management/server/account_test.go index 10d718bbf..ca3b6736e 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -394,7 +394,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") - networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetResourceMap(), nil) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index c4c5ae165..8ce8e8e99 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -323,7 +323,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.networkMapController.GetDNSDomain(account.Settings) customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) - netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), account.GetResourceMap(), nil) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } diff --git a/management/server/policy.go b/management/server/policy.go index 3e84c3d10..7d23925b8 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -75,9 +75,10 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) - } + updateAccountPeers = updateAccountPeers + // if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) + // } return policy, nil } diff --git a/management/server/types/account.go b/management/server/types/account.go index 8797e1fa3..9d7e56319 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -262,6 +262,7 @@ func (a *Account) GetPeerNetworkMap( validatedPeersMap map[string]struct{}, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter, + resources map[string]*resourceTypes.NetworkResource, metrics *telemetry.AccountManagerMetrics, ) *NetworkMap { start := time.Now() @@ -293,10 +294,10 @@ func (a *Account) GetPeerNetworkMap( routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect) routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) - isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers) + isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers, resources) var networkResourcesFirewallRules []*RouteFirewallRule if isRouter { - networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies) + networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies, routers, resources) } peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers) @@ -1325,7 +1326,7 @@ func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []strin } // GetPeerNetworkResourceFirewallRules gets the network resources firewall rules associated with a routing peer ID for the account. -func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, routes []*route.Route, resourcePolicies map[string][]*Policy) []*RouteFirewallRule { +func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, routes []*route.Route, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter, resources map[string]*resourceTypes.NetworkResource) []*RouteFirewallRule { routesFirewallRules := make([]*RouteFirewallRule, 0) for _, route := range routes { @@ -1333,7 +1334,7 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer continue } resourceAppliedPolicies := resourcePolicies[string(route.GetResourceID())] - distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) + distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups, routers, resources) rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) for _, rule := range rules { @@ -1376,7 +1377,7 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy { } // GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers. -func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) { +func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter, resources map[string]*resourceTypes.NetworkResource) (bool, []*route.Route, map[string]struct{}) { var isRoutingPeer bool var routes []*route.Route allSourcePeers := make(map[string]struct{}, len(a.Peers)) @@ -1399,11 +1400,32 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st addedResourceRoute := false for _, policy := range resourcePolicies[resource.ID] { var peers []string - if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + switch { + case policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "": peers = []string{policy.Rules[0].SourceResource.ID} - } else { + case policy.Rules[0].SourceResource.Type != "" && policy.Rules[0].SourceResource.ID != "": + if sourceResource, exists := resources[policy.Rules[0].SourceResource.ID]; exists { + sourceRoutingPeers, exists := routers[sourceResource.NetworkID] + if exists { + for _, router := range sourceRoutingPeers { + if router.Peer != "" { + peers = append(peers, router.Peer) + } + if router.PeerGroups != nil { + for _, groupID := range router.PeerGroups { + group := a.GetGroup(groupID) + if group != nil { + peers = append(peers, group.Peers...) + } + } + } + } + } + } + default: peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) } + if addSourcePeers { for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) { allSourcePeers[pID] = struct{}{} @@ -1571,8 +1593,18 @@ func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.Net return routers } +func (a *Account) GetResourceMap() map[string]*resourceTypes.NetworkResource { + resources := make(map[string]*resourceTypes.NetworkResource) + + for _, resource := range a.NetworkResources { + resources[resource.ID] = resource + } + + return resources +} + // getPoliciesSourcePeers collects all unique peers from the source groups defined in the given policies. -func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[string]struct{} { +func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group, routers map[string]map[string]*routerTypes.NetworkRouter, resources map[string]*resourceTypes.NetworkResource) map[string]struct{} { sourcePeers := make(map[string]struct{}) for _, policy := range policies { @@ -1587,6 +1619,28 @@ func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[st sourcePeers[peer] = struct{}{} } } + if (rule.SourceResource.Type == ResourceTypeHost || rule.SourceResource.Type == ResourceTypeDomain || rule.SourceResource.Type == ResourceTypeSubnet) && rule.SourceResource.ID != "" { + if resource, ok := resources[rule.SourceResource.ID]; ok { + if networkRouters, exists := routers[resource.NetworkID]; exists { + for _, router := range networkRouters { + if router.Peer != "" { + sourcePeers[router.Peer] = struct{}{} + } + if router.PeerGroups != nil { + for _, groupID := range router.PeerGroups { + group := groups[groupID] + if group != nil { + for _, peerID := range group.Peers { + sourcePeers[peerID] = struct{}{} + } + } + } + } + } + } + } + + } } } diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go index 6361e2e93..61900d074 100644 --- a/management/server/types/networkmapbuilder.go +++ b/management/server/types/networkmapbuilder.go @@ -31,13 +31,13 @@ const ( type NetworkMapCache struct { globalRoutes map[route.ID]*route.Route - globalRules map[string]*FirewallRule //ruleId - globalRouteRules map[string]*RouteFirewallRule //ruleId + globalRules map[string]*FirewallRule // ruleId + globalRouteRules map[string]*RouteFirewallRule // ruleId globalPeers map[string]*nbpeer.Peer groupToPeers map[string][]string peerToGroups map[string][]string - policyToRules map[string][]*PolicyRule //policyId + policyToRules map[string][]*PolicyRule // policyId groupToPolicies map[string][]*Policy groupToRoutes map[string][]*route.Route peerToRoutes map[string][]*route.Route @@ -651,7 +651,7 @@ func (b *NetworkMapBuilder) buildPeerRoutesView(account *Account, peerID string) } if len(networkResourcesRoutes) > 0 { - networkResourceFirewallRules := account.GetPeerNetworkResourceFirewallRules(ctx, peer, b.validatedPeers, networkResourcesRoutes, resourcePolicies) + networkResourceFirewallRules := account.GetPeerNetworkResourceFirewallRules(ctx, peer, b.validatedPeers, networkResourcesRoutes, resourcePolicies, account.GetResourceRoutersMap(), account.GetResourceMap()) for _, rule := range networkResourceFirewallRules { ruleID := b.generateRouteFirewallRuleID(rule) view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID)