diff --git a/management/server/types/firewall_rule.go b/management/server/types/firewall_rule.go index 19222a607..ce039e9a1 100644 --- a/management/server/types/firewall_rule.go +++ b/management/server/types/firewall_rule.go @@ -40,6 +40,10 @@ type FirewallRule struct { // PortRange represents the range of ports for a firewall rule PortRange RulePortRange + + PeerIPs []string + Ports []string + PortRanges []RulePortRange } // Equal checks if two firewall rules are equal. diff --git a/management/server/types/network.go b/management/server/types/network.go index 7ed13a104..0f45d410a 100644 --- a/management/server/types/network.go +++ b/management/server/types/network.go @@ -78,6 +78,58 @@ func (nm *NetworkMap) UncompactRoutes() { nm.Routes = uncompactedRoutes } +func (nm *NetworkMap) UncompactFirewallRules() { + uncompactedRules := make([]*FirewallRule, 0, len(nm.FirewallRules)*2) + + for _, compactRule := range nm.FirewallRules { + if len(compactRule.PeerIPs) == 0 { + uncompactedRules = append(uncompactedRules, compactRule) + continue + } + + for _, peerIP := range compactRule.PeerIPs { + if len(compactRule.Ports) > 0 { + for _, port := range compactRule.Ports { + expandedRule := &FirewallRule{ + PolicyID: compactRule.PolicyID, + PeerIP: peerIP, + Direction: compactRule.Direction, + Action: compactRule.Action, + Protocol: compactRule.Protocol, + Port: port, + } + uncompactedRules = append(uncompactedRules, expandedRule) + } + } else if len(compactRule.PortRanges) > 0 { + for _, portRange := range compactRule.PortRanges { + expandedRule := &FirewallRule{ + PolicyID: compactRule.PolicyID, + PeerIP: peerIP, + Direction: compactRule.Direction, + Action: compactRule.Action, + Protocol: compactRule.Protocol, + PortRange: portRange, + } + uncompactedRules = append(uncompactedRules, expandedRule) + } + } else { + expandedRule := &FirewallRule{ + PolicyID: compactRule.PolicyID, + PeerIP: peerIP, + Direction: compactRule.Direction, + Action: compactRule.Action, + Protocol: compactRule.Protocol, + Port: compactRule.Port, + PortRange: compactRule.PortRange, + } + uncompactedRules = append(uncompactedRules, expandedRule) + } + } + } + + nm.FirewallRules = uncompactedRules +} + func (nm *NetworkMap) ValidateApplicablePeerIDs(compactNm *NetworkMap, expectedPermsMap map[string]map[string]bool) error { if compactNm == nil { return fmt.Errorf("compact network map is nil") diff --git a/management/server/types/networkmap_golden_test.go b/management/server/types/networkmap_golden_test.go index ddf2fbfe7..41491255f 100644 --- a/management/server/types/networkmap_golden_test.go +++ b/management/server/types/networkmap_golden_test.go @@ -886,6 +886,9 @@ func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) { if r1.PeerIP != r2.PeerIP { return r1.PeerIP < r2.PeerIP } + if r1.PolicyID != r2.PolicyID { + return r1.PolicyID < r2.PolicyID + } if r1.Protocol != r2.Protocol { return r1.Protocol < r2.Protocol } @@ -1104,7 +1107,14 @@ func TestGetPeerNetworkMapCompact(t *testing.T) { compactedJSON, err := json.MarshalIndent(compactNm, "", " ") require.NoError(t, err) + compactedBeforeUncompact := filepath.Join("testdata", "compact_before_uncompact.json") + err = os.MkdirAll(filepath.Dir(compactedBeforeUncompact), 0755) + require.NoError(t, err) + err = os.WriteFile(compactedBeforeUncompact, compactedJSON, 0644) + require.NoError(t, err) + compactNm.UncompactRoutes() + compactNm.UncompactFirewallRules() normalizeAndSortNetworkMap(regularNm) normalizeAndSortNetworkMap(compactNm) diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go index 21ea8373e..bceb48014 100644 --- a/management/server/types/networkmapbuilder.go +++ b/management/server/types/networkmapbuilder.go @@ -1144,12 +1144,13 @@ func (b *NetworkMapBuilder) assembleNetworkMapCompact( routes = append(routes, crt.route) } - var firewallRules []*FirewallRule + var expandedFirewallRules []*FirewallRule for _, ruleID := range aclView.FirewallRuleIDs { if rule := b.cache.globalRules[ruleID]; rule != nil { - firewallRules = append(firewallRules, rule) + expandedFirewallRules = append(expandedFirewallRules, rule) } } + firewallRules := compactFirewallRules(expandedFirewallRules) var routesFirewallRules []*RouteFirewallRule for _, ruleID := range routesView.RouteFirewallRuleIDs { @@ -1188,6 +1189,97 @@ func splitRouteAndPeer(r *route.Route) (string, string) { return parts[0], parts[1] } +func compactFirewallRules(expandedRules []*FirewallRule) []*FirewallRule { + type peerKey struct { + PolicyID string + PeerIP string + Direction int + Action string + Protocol string + } + + peerGroups := make(map[peerKey]struct { + ports []string + portRanges []RulePortRange + }) + + for _, rule := range expandedRules { + key := peerKey{ + PolicyID: rule.PolicyID, + PeerIP: rule.PeerIP, + Direction: rule.Direction, + Action: rule.Action, + Protocol: rule.Protocol, + } + + group := peerGroups[key] + if rule.Port != "" { + group.ports = append(group.ports, rule.Port) + } + if rule.PortRange.Start != 0 || rule.PortRange.End != 0 { + group.portRanges = append(group.portRanges, rule.PortRange) + } + peerGroups[key] = group + } + + type ruleKey struct { + PolicyID string + Direction int + Action string + Protocol string + PortsSig string + RangesSig string + } + + ruleGroups := make(map[ruleKey]struct { + peerIPs []string + ports []string + portRanges []RulePortRange + }) + + for pKey, pGroup := range peerGroups { + portsSig := strings.Join(pGroup.ports, ",") + rangesSig := fmt.Sprintf("%v", pGroup.portRanges) + + rKey := ruleKey{ + PolicyID: pKey.PolicyID, + Direction: pKey.Direction, + Action: pKey.Action, + Protocol: pKey.Protocol, + PortsSig: portsSig, + RangesSig: rangesSig, + } + + group := ruleGroups[rKey] + group.peerIPs = append(group.peerIPs, pKey.PeerIP) + if len(group.ports) == 0 { + group.ports = pGroup.ports + } + if len(group.portRanges) == 0 { + group.portRanges = pGroup.portRanges + } + ruleGroups[rKey] = group + } + + compactRules := make([]*FirewallRule, 0, len(ruleGroups)) + + for rKey, group := range ruleGroups { + compactRule := &FirewallRule{ + PolicyID: rKey.PolicyID, + Direction: rKey.Direction, + Action: rKey.Action, + Protocol: rKey.Protocol, + PeerIPs: group.peerIPs, + Ports: group.ports, + PortRanges: group.portRanges, + } + + compactRules = append(compactRules, compactRule) + } + + return compactRules +} + func (b *NetworkMapBuilder) generateFirewallRuleID(rule *FirewallRule) string { var s strings.Builder s.WriteString(fw)