mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Refactor protocol handling for firewall rules, add engine tests
This commit is contained in:
@@ -226,6 +226,10 @@ func (e *Engine) Start() error {
|
||||
e.firewallManager, err = buildFirewallManager()
|
||||
if err != nil {
|
||||
log.Errorf("failed to create firewall manager, ACL policy will not work: %s", err.Error())
|
||||
} else {
|
||||
if err := e.firewallManager.Reset(); err != nil {
|
||||
log.Tracef("failed to reset firewall manager on the start: %v", err.Error())
|
||||
}
|
||||
}
|
||||
e.firewallRules = make(map[string]firewall.Rule)
|
||||
|
||||
@@ -632,9 +636,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
if err := e.applyFirewallRules(networkMap.FirewallRules); err != nil {
|
||||
log.Errorf("failed apply firewall rules, err: %v", err)
|
||||
}
|
||||
e.applyFirewallRules(networkMap.FirewallRules)
|
||||
e.networkSerial = serial
|
||||
return nil
|
||||
}
|
||||
@@ -1049,68 +1051,58 @@ func (e *Engine) close() {
|
||||
}
|
||||
|
||||
// applyFirewallRules to the local firewall manager processed by ACL policy.
|
||||
func (e *Engine) applyFirewallRules(rules []*mgmProto.FirewallRule) error {
|
||||
func (e *Engine) applyFirewallRules(rules []*mgmProto.FirewallRule) {
|
||||
if e.firewallManager == nil {
|
||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
for ruleID, rule := range e.firewallRules {
|
||||
if err := e.firewallManager.DeleteRule(rule); err != nil {
|
||||
log.Errorf("failed to delete firewall rule: %v", err)
|
||||
continue
|
||||
}
|
||||
delete(e.firewallRules, ruleID)
|
||||
}
|
||||
|
||||
newRules := make(map[string]struct{}, 0)
|
||||
for _, r := range rules {
|
||||
rule := e.protoRuleToFirewallRule(r)
|
||||
if rule == nil {
|
||||
continue
|
||||
}
|
||||
newRules[rule.GetRuleID()] = struct{}{}
|
||||
}
|
||||
|
||||
for ruleID := range e.firewallRules {
|
||||
if _, ok := newRules[ruleID]; ok {
|
||||
continue
|
||||
}
|
||||
if rule, ok := e.firewallRules[ruleID]; ok {
|
||||
if err := e.firewallManager.DeleteRule(rule); err != nil {
|
||||
log.Debugf("failed to delete firewall rule: %v", err)
|
||||
continue
|
||||
}
|
||||
delete(e.firewallRules, ruleID)
|
||||
if rule := e.protoRuleToFirewallRule(r); rule == nil {
|
||||
log.Errorf("failed to apply firewall rule: %v", r)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule {
|
||||
ip := net.ParseIP(r.PeerIP)
|
||||
if ip == nil {
|
||||
log.Debug("invalid IP address, skipping firewall rule")
|
||||
log.Error("invalid IP address, skipping firewall rule")
|
||||
return nil
|
||||
}
|
||||
|
||||
var port *firewall.Port
|
||||
if r.Port != "" {
|
||||
split := strings.Split(r.Port, "/")
|
||||
value, err := strconv.Atoi(split[0])
|
||||
if err != nil {
|
||||
log.Debug("invalid port, skipping firewall rule")
|
||||
return nil
|
||||
}
|
||||
// port can be empty, so ignore conversion error
|
||||
value, _ := strconv.Atoi(split[0])
|
||||
port = &firewall.Port{}
|
||||
port.Values = []int{value}
|
||||
// get protocol from the port suffix if it exists
|
||||
if len(split) > 1 {
|
||||
switch split[1] {
|
||||
case "tcp":
|
||||
port.Proto = firewall.PortProtocolTCP
|
||||
case "udp":
|
||||
port.Proto = firewall.PortProtocolUDP
|
||||
default:
|
||||
log.Debug("invalid protocol, skipping firewall rule")
|
||||
return nil
|
||||
}
|
||||
if value != 0 {
|
||||
port.Values = []int{value}
|
||||
}
|
||||
}
|
||||
|
||||
var protocol firewall.Protocol
|
||||
switch r.Protocol {
|
||||
case "tcp":
|
||||
protocol = firewall.ProtocolTCP
|
||||
case "udp":
|
||||
protocol = firewall.ProtocolUDP
|
||||
case "icmp":
|
||||
protocol = firewall.ProtocolICMP
|
||||
default:
|
||||
log.Errorf("invalid protocol, skipping firewall rule: %q", r.Protocol)
|
||||
return nil
|
||||
}
|
||||
|
||||
var direction firewall.Direction
|
||||
switch r.Direction {
|
||||
case "src":
|
||||
@@ -1118,7 +1110,7 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
|
||||
case "dst":
|
||||
direction = firewall.DirectionDst
|
||||
default:
|
||||
log.Debug("invalid direction, skipping firewall rule")
|
||||
log.Error("invalid direction, skipping firewall rule")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1129,13 +1121,13 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
|
||||
case "drop":
|
||||
action = firewall.ActionDrop
|
||||
default:
|
||||
log.Debug("invalid action, skipping firewall rule")
|
||||
log.Error("invalid action, skipping firewall rule")
|
||||
return nil
|
||||
}
|
||||
|
||||
rule, err := e.firewallManager.AddFiltering(ip, port, direction, action, "")
|
||||
rule, err := e.firewallManager.AddFiltering(ip, protocol, port, direction, action, "")
|
||||
if err != nil {
|
||||
log.Debugf("failed to add firewall rule: %v", err)
|
||||
log.Errorf("failed to add firewall rule: %v", err)
|
||||
return nil
|
||||
}
|
||||
e.firewallRules[rule.GetRuleID()] = rule
|
||||
|
||||
Reference in New Issue
Block a user