Refactor protocol handling for firewall rules, add engine tests

This commit is contained in:
Givi Khojanashvili
2023-03-21 17:56:47 +04:00
parent 0abd05d51e
commit 64ad771099
10 changed files with 263 additions and 112 deletions

View File

@@ -226,6 +226,10 @@ func (e *Engine) Start() error {
e.firewallManager, err = buildFirewallManager()
if err != nil {
log.Errorf("failed to create firewall manager, ACL policy will not work: %s", err.Error())
} else {
if err := e.firewallManager.Reset(); err != nil {
log.Tracef("failed to reset firewall manager on the start: %v", err.Error())
}
}
e.firewallRules = make(map[string]firewall.Rule)
@@ -632,9 +636,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update dns server, err: %v", err)
}
if err := e.applyFirewallRules(networkMap.FirewallRules); err != nil {
log.Errorf("failed apply firewall rules, err: %v", err)
}
e.applyFirewallRules(networkMap.FirewallRules)
e.networkSerial = serial
return nil
}
@@ -1049,68 +1051,58 @@ func (e *Engine) close() {
}
// applyFirewallRules to the local firewall manager processed by ACL policy.
func (e *Engine) applyFirewallRules(rules []*mgmProto.FirewallRule) error {
func (e *Engine) applyFirewallRules(rules []*mgmProto.FirewallRule) {
if e.firewallManager == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return nil
return
}
for ruleID, rule := range e.firewallRules {
if err := e.firewallManager.DeleteRule(rule); err != nil {
log.Errorf("failed to delete firewall rule: %v", err)
continue
}
delete(e.firewallRules, ruleID)
}
newRules := make(map[string]struct{}, 0)
for _, r := range rules {
rule := e.protoRuleToFirewallRule(r)
if rule == nil {
continue
}
newRules[rule.GetRuleID()] = struct{}{}
}
for ruleID := range e.firewallRules {
if _, ok := newRules[ruleID]; ok {
continue
}
if rule, ok := e.firewallRules[ruleID]; ok {
if err := e.firewallManager.DeleteRule(rule); err != nil {
log.Debugf("failed to delete firewall rule: %v", err)
continue
}
delete(e.firewallRules, ruleID)
if rule := e.protoRuleToFirewallRule(r); rule == nil {
log.Errorf("failed to apply firewall rule: %v", r)
}
}
return nil
}
func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule {
ip := net.ParseIP(r.PeerIP)
if ip == nil {
log.Debug("invalid IP address, skipping firewall rule")
log.Error("invalid IP address, skipping firewall rule")
return nil
}
var port *firewall.Port
if r.Port != "" {
split := strings.Split(r.Port, "/")
value, err := strconv.Atoi(split[0])
if err != nil {
log.Debug("invalid port, skipping firewall rule")
return nil
}
// port can be empty, so ignore conversion error
value, _ := strconv.Atoi(split[0])
port = &firewall.Port{}
port.Values = []int{value}
// get protocol from the port suffix if it exists
if len(split) > 1 {
switch split[1] {
case "tcp":
port.Proto = firewall.PortProtocolTCP
case "udp":
port.Proto = firewall.PortProtocolUDP
default:
log.Debug("invalid protocol, skipping firewall rule")
return nil
}
if value != 0 {
port.Values = []int{value}
}
}
var protocol firewall.Protocol
switch r.Protocol {
case "tcp":
protocol = firewall.ProtocolTCP
case "udp":
protocol = firewall.ProtocolUDP
case "icmp":
protocol = firewall.ProtocolICMP
default:
log.Errorf("invalid protocol, skipping firewall rule: %q", r.Protocol)
return nil
}
var direction firewall.Direction
switch r.Direction {
case "src":
@@ -1118,7 +1110,7 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
case "dst":
direction = firewall.DirectionDst
default:
log.Debug("invalid direction, skipping firewall rule")
log.Error("invalid direction, skipping firewall rule")
return nil
}
@@ -1129,13 +1121,13 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
case "drop":
action = firewall.ActionDrop
default:
log.Debug("invalid action, skipping firewall rule")
log.Error("invalid action, skipping firewall rule")
return nil
}
rule, err := e.firewallManager.AddFiltering(ip, port, direction, action, "")
rule, err := e.firewallManager.AddFiltering(ip, protocol, port, direction, action, "")
if err != nil {
log.Debugf("failed to add firewall rule: %v", err)
log.Errorf("failed to add firewall rule: %v", err)
return nil
}
e.firewallRules[rule.GetRuleID()] = rule