mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
Fix non-port based rules processing. Add rules clean up call.
This commit is contained in:
@@ -39,6 +39,9 @@ const (
|
||||
// Netbird client for ACL and routing functionality
|
||||
type Manager interface {
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
AddFiltering(
|
||||
ip net.IP,
|
||||
port *Port,
|
||||
|
||||
@@ -50,6 +50,8 @@ func Create() (*Manager, error) {
|
||||
}
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment is empty rule ID is used as comment
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
port *fw.Port,
|
||||
@@ -59,33 +61,35 @@ func (m *Manager) AddFiltering(
|
||||
) (fw.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
client := m.client(ip)
|
||||
ok, err := client.ChainExists("filter", ChainFilterName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if chain exists: %s", err)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
if err := client.NewChain("filter", ChainFilterName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create chain: %s", err)
|
||||
}
|
||||
}
|
||||
if port == nil || port.Values == nil || (port.IsRange && len(port.Values) != 2) {
|
||||
return nil, fmt.Errorf("invalid port definition")
|
||||
|
||||
var pv string
|
||||
if port != nil && port.Values != nil {
|
||||
// TODO: we support only one port per rule in current implementation of ACLs
|
||||
pv = strconv.Itoa(port.Values[0])
|
||||
}
|
||||
pv := strconv.Itoa(port.Values[0])
|
||||
if port.IsRange {
|
||||
pv += ":" + strconv.Itoa(port.Values[1])
|
||||
ruleID := uuid.New().String()
|
||||
if comment == "" {
|
||||
comment = ruleID
|
||||
}
|
||||
|
||||
specs := m.filterRuleSpecs("filter", ChainFilterName, ip, pv, direction, action, comment)
|
||||
if err := client.AppendUnique("filter", ChainFilterName, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rule := &Rule{
|
||||
id: uuid.New().String(),
|
||||
specs: specs,
|
||||
v6: ip.To4() == nil,
|
||||
}
|
||||
return rule, nil
|
||||
|
||||
return &Rule{id: ruleID, specs: specs, v6: ip.To4() == nil}, nil
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
@@ -139,7 +143,10 @@ func (m *Manager) filterRuleSpecs(
|
||||
if direction == fw.DirectionSrc {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
specs = append(specs, "-p", "tcp", "--dport", port)
|
||||
specs = append(specs, "-p", "tcp")
|
||||
if port != "" {
|
||||
specs = append(specs, "--dport", port)
|
||||
}
|
||||
specs = append(specs, "-j", m.actionToStr(action))
|
||||
return append(specs, "-m", "comment", "--comment", comment)
|
||||
}
|
||||
|
||||
@@ -1041,6 +1041,11 @@ func (e *Engine) close() {
|
||||
e.dnsServer.Stop()
|
||||
}
|
||||
|
||||
if e.firewallManager != nil {
|
||||
if err := e.firewallManager.Reset(); err != nil {
|
||||
log.Warnf("failed resetting firewall: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// applyFirewallRules to the local firewall manager processed by ACL policy.
|
||||
@@ -1050,16 +1055,19 @@ func (e *Engine) applyFirewallRules(rules []*mgmProto.FirewallRule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
newRules := make([]string, 0)
|
||||
newRules := make(map[string]struct{}, 0)
|
||||
for _, r := range rules {
|
||||
rule := e.protoRuleToFirewallRule(r)
|
||||
if rule == nil {
|
||||
continue
|
||||
}
|
||||
newRules = append(newRules, rule.GetRuleID())
|
||||
newRules[rule.GetRuleID()] = struct{}{}
|
||||
}
|
||||
|
||||
for _, ruleID := range newRules {
|
||||
for ruleID := range e.firewallRules {
|
||||
if _, ok := newRules[ruleID]; ok {
|
||||
continue
|
||||
}
|
||||
if rule, ok := e.firewallRules[ruleID]; ok {
|
||||
if err := e.firewallManager.DeleteRule(rule); err != nil {
|
||||
log.Debug("failed to delete firewall rule: %v", err)
|
||||
@@ -1079,7 +1087,7 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
|
||||
return nil
|
||||
}
|
||||
|
||||
var port firewall.Port
|
||||
var port *firewall.Port
|
||||
if r.Port != "" {
|
||||
split := strings.Split(r.Port, "/")
|
||||
value, err := strconv.Atoi(split[0])
|
||||
@@ -1087,6 +1095,7 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
|
||||
log.Debug("invalid port, skipping firewall rule")
|
||||
return nil
|
||||
}
|
||||
port = &firewall.Port{}
|
||||
port.Values = []int{value}
|
||||
// get protocol from the port suffix if it exists
|
||||
if len(split) > 1 {
|
||||
@@ -1124,9 +1133,9 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
|
||||
return nil
|
||||
}
|
||||
|
||||
rule, err := e.firewallManager.AddFiltering(ip, &port, direction, action, "")
|
||||
rule, err := e.firewallManager.AddFiltering(ip, port, direction, action, "")
|
||||
if err != nil {
|
||||
log.Debug("failed to add firewall rule: %v", err)
|
||||
log.Debugf("failed to add firewall rule: %v", err)
|
||||
return nil
|
||||
}
|
||||
e.firewallRules[rule.GetRuleID()] = rule
|
||||
|
||||
Reference in New Issue
Block a user