Fix non-port based rules processing. Add rules clean up call.

This commit is contained in:
Givi Khojanashvili
2023-03-20 14:52:03 +04:00
parent 64bf7699e6
commit 559cf2862f
3 changed files with 37 additions and 18 deletions

View File

@@ -1041,6 +1041,11 @@ func (e *Engine) close() {
e.dnsServer.Stop()
}
if e.firewallManager != nil {
if err := e.firewallManager.Reset(); err != nil {
log.Warnf("failed resetting firewall: %v", err)
}
}
}
// applyFirewallRules to the local firewall manager processed by ACL policy.
@@ -1050,16 +1055,19 @@ func (e *Engine) applyFirewallRules(rules []*mgmProto.FirewallRule) error {
return nil
}
newRules := make([]string, 0)
newRules := make(map[string]struct{}, 0)
for _, r := range rules {
rule := e.protoRuleToFirewallRule(r)
if rule == nil {
continue
}
newRules = append(newRules, rule.GetRuleID())
newRules[rule.GetRuleID()] = struct{}{}
}
for _, ruleID := range newRules {
for ruleID := range e.firewallRules {
if _, ok := newRules[ruleID]; ok {
continue
}
if rule, ok := e.firewallRules[ruleID]; ok {
if err := e.firewallManager.DeleteRule(rule); err != nil {
log.Debug("failed to delete firewall rule: %v", err)
@@ -1079,7 +1087,7 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
return nil
}
var port firewall.Port
var port *firewall.Port
if r.Port != "" {
split := strings.Split(r.Port, "/")
value, err := strconv.Atoi(split[0])
@@ -1087,6 +1095,7 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
log.Debug("invalid port, skipping firewall rule")
return nil
}
port = &firewall.Port{}
port.Values = []int{value}
// get protocol from the port suffix if it exists
if len(split) > 1 {
@@ -1124,9 +1133,9 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
return nil
}
rule, err := e.firewallManager.AddFiltering(ip, &port, direction, action, "")
rule, err := e.firewallManager.AddFiltering(ip, port, direction, action, "")
if err != nil {
log.Debug("failed to add firewall rule: %v", err)
log.Debugf("failed to add firewall rule: %v", err)
return nil
}
e.firewallRules[rule.GetRuleID()] = rule