Fix non-port based rules processing. Add rules clean up call.

This commit is contained in:
Givi Khojanashvili
2023-03-20 14:52:03 +04:00
parent 64bf7699e6
commit 559cf2862f
3 changed files with 37 additions and 18 deletions

View File

@@ -39,6 +39,9 @@ const (
// Netbird client for ACL and routing functionality
type Manager interface {
// AddFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
AddFiltering(
ip net.IP,
port *Port,

View File

@@ -50,6 +50,8 @@ func Create() (*Manager, error) {
}
// AddFiltering rule to the firewall
//
// If comment is empty rule ID is used as comment
func (m *Manager) AddFiltering(
ip net.IP,
port *fw.Port,
@@ -59,33 +61,35 @@ func (m *Manager) AddFiltering(
) (fw.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
client := m.client(ip)
ok, err := client.ChainExists("filter", ChainFilterName)
if err != nil {
return nil, fmt.Errorf("failed to check if chain exists: %s", err)
}
if !ok {
if err := client.NewChain("filter", ChainFilterName); err != nil {
return nil, fmt.Errorf("failed to create chain: %s", err)
}
}
if port == nil || port.Values == nil || (port.IsRange && len(port.Values) != 2) {
return nil, fmt.Errorf("invalid port definition")
var pv string
if port != nil && port.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs
pv = strconv.Itoa(port.Values[0])
}
pv := strconv.Itoa(port.Values[0])
if port.IsRange {
pv += ":" + strconv.Itoa(port.Values[1])
ruleID := uuid.New().String()
if comment == "" {
comment = ruleID
}
specs := m.filterRuleSpecs("filter", ChainFilterName, ip, pv, direction, action, comment)
if err := client.AppendUnique("filter", ChainFilterName, specs...); err != nil {
return nil, err
}
rule := &Rule{
id: uuid.New().String(),
specs: specs,
v6: ip.To4() == nil,
}
return rule, nil
return &Rule{id: ruleID, specs: specs, v6: ip.To4() == nil}, nil
}
// DeleteRule from the firewall by rule definition
@@ -139,7 +143,10 @@ func (m *Manager) filterRuleSpecs(
if direction == fw.DirectionSrc {
specs = append(specs, "-s", ip.String())
}
specs = append(specs, "-p", "tcp", "--dport", port)
specs = append(specs, "-p", "tcp")
if port != "" {
specs = append(specs, "--dport", port)
}
specs = append(specs, "-j", m.actionToStr(action))
return append(specs, "-m", "comment", "--comment", comment)
}

View File

@@ -1041,6 +1041,11 @@ func (e *Engine) close() {
e.dnsServer.Stop()
}
if e.firewallManager != nil {
if err := e.firewallManager.Reset(); err != nil {
log.Warnf("failed resetting firewall: %v", err)
}
}
}
// applyFirewallRules to the local firewall manager processed by ACL policy.
@@ -1050,16 +1055,19 @@ func (e *Engine) applyFirewallRules(rules []*mgmProto.FirewallRule) error {
return nil
}
newRules := make([]string, 0)
newRules := make(map[string]struct{}, 0)
for _, r := range rules {
rule := e.protoRuleToFirewallRule(r)
if rule == nil {
continue
}
newRules = append(newRules, rule.GetRuleID())
newRules[rule.GetRuleID()] = struct{}{}
}
for _, ruleID := range newRules {
for ruleID := range e.firewallRules {
if _, ok := newRules[ruleID]; ok {
continue
}
if rule, ok := e.firewallRules[ruleID]; ok {
if err := e.firewallManager.DeleteRule(rule); err != nil {
log.Debug("failed to delete firewall rule: %v", err)
@@ -1079,7 +1087,7 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
return nil
}
var port firewall.Port
var port *firewall.Port
if r.Port != "" {
split := strings.Split(r.Port, "/")
value, err := strconv.Atoi(split[0])
@@ -1087,6 +1095,7 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
log.Debug("invalid port, skipping firewall rule")
return nil
}
port = &firewall.Port{}
port.Values = []int{value}
// get protocol from the port suffix if it exists
if len(split) > 1 {
@@ -1124,9 +1133,9 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule
return nil
}
rule, err := e.firewallManager.AddFiltering(ip, &port, direction, action, "")
rule, err := e.firewallManager.AddFiltering(ip, port, direction, action, "")
if err != nil {
log.Debug("failed to add firewall rule: %v", err)
log.Debugf("failed to add firewall rule: %v", err)
return nil
}
e.firewallRules[rule.GetRuleID()] = rule