diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 0c1160cda..679ec3b86 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -58,6 +58,11 @@ func TestAccount_getPeersByPolicy(t *testing.T) { IP: net.ParseIP("100.65.29.55"), Status: &nbpeer.PeerStatus{}, }, + "peerI": { + ID: "peerI", + IP: net.ParseIP("100.65.31.2"), + Status: &nbpeer.PeerStatus{}, + }, }, Groups: map[string]*types.Group{ "GroupAll": { @@ -99,6 +104,13 @@ func TestAccount_getPeersByPolicy(t *testing.T) { "peerH", }, }, + "GroupDMZ": { + ID: "GroupDMZ", + Name: "dmz", + Peers: []string{ + "peerI", + }, + }, }, Policies: []*types.Policy{ { @@ -148,6 +160,35 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }, }, }, + { + ID: "RuleDMZ", + Name: "Dmz", + Description: "No description", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "RuleDMZ", + Name: "Dmz", + Description: "No description", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + PortRanges: []types.RulePortRange{ + { + Start: 8080, + End: 8083, + }, + }, + Sources: []string{ + "GroupWorkstations", + }, + Destinations: []string{ + "GroupDMZ", + }, + }, + }, + }, }, } @@ -166,7 +207,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { t.Run("check first peer map details", func(t *testing.T) { peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers) - assert.Len(t, peers, 7) + assert.Len(t, peers, 8) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) @@ -174,8 +215,9 @@ func TestAccount_getPeersByPolicy(t *testing.T) { assert.Contains(t, peers, account.Peers["peerF"]) assert.Contains(t, peers, account.Peers["peerG"]) assert.Contains(t, peers, account.Peers["peerH"]) + assert.Contains(t, peers, account.Peers["peerI"]) - epectedFirewallRules := []*types.FirewallRule{ + expectedFirewallRules := []*types.FirewallRule{ { PeerIP: "0.0.0.0", Direction: types.FirewallRuleDirectionIN, @@ -292,12 +334,28 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Port: "", PolicyID: "RuleSwarm", }, + { + PeerIP: "100.65.31.2", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "tcp", + PortRange: types.RulePortRange{Start: 8080, End: 8083}, + PolicyID: "RuleDMZ", + }, + { + PeerIP: "100.65.31.2", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + PortRange: types.RulePortRange{Start: 8080, End: 8083}, + PolicyID: "RuleDMZ", + }, } - assert.Len(t, firewallRules, len(epectedFirewallRules)) + assert.Len(t, firewallRules, len(expectedFirewallRules)) for _, rule := range firewallRules { contains := false - for _, expectedRule := range epectedFirewallRules { + for _, expectedRule := range expectedFirewallRules { if rule.Equal(expectedRule) { contains = true break diff --git a/management/server/types/account.go b/management/server/types/account.go index 8315f5796..da230f0b2 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -1046,7 +1046,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, } rulesExists[ruleID] = struct{}{} - if len(rule.Ports) == 0 { + if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { rules = append(rules, &fr) continue } @@ -1056,6 +1056,12 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, pr.Port = port rules = append(rules, &pr) } + + for _, portRange := range rule.PortRanges { + pr := fr + pr.PortRange = portRange + rules = append(rules, &pr) + } } }, func() ([]*nbpeer.Peer, []*FirewallRule) { return peers, rules