diff --git a/client/firewall/port.go b/client/firewall/port.go index 65a16a16e..7681f29c3 100644 --- a/client/firewall/port.go +++ b/client/firewall/port.go @@ -1,5 +1,9 @@ package firewall +import ( + "strconv" +) + // Protocol is the protocol of the port type Protocol string @@ -28,3 +32,15 @@ type Port struct { // Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports Values []int } + +// String interface implementation +func (p *Port) String() string { + var ports string + for _, port := range p.Values { + if ports != "" { + ports += "," + } + ports += strconv.Itoa(port) + } + return ports +} diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 83d0709f7..95f2c253e 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -1,10 +1,13 @@ package acl import ( + "crypto/md5" + "encoding/hex" "fmt" "net" "strconv" "sync" + "time" log "github.com/sirupsen/logrus" @@ -42,6 +45,17 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { d.mutex.Lock() defer d.mutex.Unlock() + start := time.Now() + defer func() { + total := 0 + for _, pairs := range d.rulesPairs { + total += len(pairs) + } + log.Infof( + "ACL rules processed in: %v, total rules count: %d", + time.Since(start), total) + }() + if d.manager == nil { log.Debug("firewall manager is not supported, skipping firewall rules") return @@ -95,13 +109,13 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { applyFailed := false newRulePairs := make(map[string][]firewall.Rule) for _, r := range rules { - rulePair, err := d.protoRuleToFirewallRule(r) + pairID, rulePair, err := d.protoRuleToFirewallRule(r) if err != nil { log.Errorf("failed to apply firewall rule: %+v, %v", r, err) applyFailed = true break } - newRulePairs[rulePair[0].GetRuleID()] = rulePair + newRulePairs[pairID] = rulePair } if applyFailed { log.Error("failed to apply firewall rules, rollback ACL to previous state") @@ -140,33 +154,38 @@ func (d *DefaultManager) Stop() { } } -func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) ([]firewall.Rule, error) { +func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (string, []firewall.Rule, error) { ip := net.ParseIP(r.PeerIP) if ip == nil { - return nil, fmt.Errorf("invalid IP address, skipping firewall rule") + return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule") } protocol := convertToFirewallProtocol(r.Protocol) if protocol == firewall.ProtocolUnknown { - return nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol) + return "", nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol) } action := convertFirewallAction(r.Action) if action == firewall.ActionUnknown { - return nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action) + return "", nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action) } var port *firewall.Port if r.Port != "" { value, err := strconv.Atoi(r.Port) if err != nil { - return nil, fmt.Errorf("invalid port, skipping firewall rule") + return "", nil, fmt.Errorf("invalid port, skipping firewall rule") } port = &firewall.Port{ Values: []int{value}, } } + ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "") + if rulesPair, ok := d.rulesPairs[ruleID]; ok { + return ruleID, rulesPair, nil + } + var rules []firewall.Rule var err error switch r.Direction { @@ -175,15 +194,15 @@ func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) ([]fi case mgmProto.FirewallRule_OUT: rules, err = d.addOutRules(ip, protocol, port, action, "") default: - return nil, fmt.Errorf("invalid direction, skipping firewall rule") + return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") } if err != nil { - return nil, err + return "", nil, err } - d.rulesPairs[rules[0].GetRuleID()] = rules - return rules, nil + d.rulesPairs[ruleID] = rules + return ruleID, rules, nil } func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) { @@ -226,6 +245,23 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port return append(rules, rule), nil } +// getRuleID() returns unique ID for the rule based on its parameters. +func (d *DefaultManager) getRuleID( + ip net.IP, + proto firewall.Protocol, + direction int, + port *firewall.Port, + action firewall.Action, + comment string, +) string { + idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment + if port != nil { + idStr += port.String() + } + + return hex.EncodeToString(md5.New().Sum([]byte(idStr))) +} + // squashAcceptRules does complex logic to convert many rules which allows connection by traffic type // to all peers in the network map to one rule which just accepts that type of the traffic. // @@ -235,7 +271,7 @@ func (d *DefaultManager) squashAcceptRules( networkMap *mgmProto.NetworkMap, ) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) { totalIPs := 0 - for _, p := range networkMap.RemotePeers { + for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) { for range p.AllowedIps { totalIPs++ } diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 366877337..d765e5c6c 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -55,6 +55,11 @@ func TestDefaultManager(t *testing.T) { }) t.Run("add extra rules", func(t *testing.T) { + existedPairs := map[string]struct{}{} + for id := range acl.rulesPairs { + existedPairs[id] = struct{}{} + } + // remove first rule networkMap.FirewallRules = networkMap.FirewallRules[1:] networkMap.FirewallRules = append( @@ -67,11 +72,6 @@ func TestDefaultManager(t *testing.T) { }, ) - existedRulesID := map[string]struct{}{} - for id := range acl.rulesPairs { - existedRulesID[id] = struct{}{} - } - acl.ApplyFiltering(networkMap) // we should have one old and one new rule in the existed rules @@ -80,13 +80,16 @@ func TestDefaultManager(t *testing.T) { return } - // check that old rules was removed - for id := range existedRulesID { - if _, ok := acl.rulesPairs[id]; ok { - t.Errorf("old rule was not removed") - return + // check that old rule was removed + previousCount := 0 + for id := range acl.rulesPairs { + if _, ok := existedPairs[id]; ok { + previousCount++ } } + if previousCount != 1 { + t.Errorf("old rule was not removed") + } }) t.Run("handle default rules", func(t *testing.T) {