diff --git a/client/firewall/firewall.go b/client/firewall/firewall.go index 2e685e15c..ee41d4f9e 100644 --- a/client/firewall/firewall.go +++ b/client/firewall/firewall.go @@ -39,6 +39,9 @@ const ( // Netbird client for ACL and routing functionality type Manager interface { // AddFiltering rule to the firewall + // + // If comment argument is empty firewall manager should set + // rule ID as comment for the rule AddFiltering( ip net.IP, port *Port, diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index e5aafd6d8..31e5b1ff0 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -50,6 +50,8 @@ func Create() (*Manager, error) { } // AddFiltering rule to the firewall +// +// If comment is empty rule ID is used as comment func (m *Manager) AddFiltering( ip net.IP, port *fw.Port, @@ -59,33 +61,35 @@ func (m *Manager) AddFiltering( ) (fw.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() + client := m.client(ip) ok, err := client.ChainExists("filter", ChainFilterName) if err != nil { return nil, fmt.Errorf("failed to check if chain exists: %s", err) } + if !ok { if err := client.NewChain("filter", ChainFilterName); err != nil { return nil, fmt.Errorf("failed to create chain: %s", err) } } - if port == nil || port.Values == nil || (port.IsRange && len(port.Values) != 2) { - return nil, fmt.Errorf("invalid port definition") + + var pv string + if port != nil && port.Values != nil { + // TODO: we support only one port per rule in current implementation of ACLs + pv = strconv.Itoa(port.Values[0]) } - pv := strconv.Itoa(port.Values[0]) - if port.IsRange { - pv += ":" + strconv.Itoa(port.Values[1]) + ruleID := uuid.New().String() + if comment == "" { + comment = ruleID } + specs := m.filterRuleSpecs("filter", ChainFilterName, ip, pv, direction, action, comment) if err := client.AppendUnique("filter", ChainFilterName, specs...); err != nil { return nil, err } - rule := &Rule{ - id: uuid.New().String(), - specs: specs, - v6: ip.To4() == nil, - } - return rule, nil + + return &Rule{id: ruleID, specs: specs, v6: ip.To4() == nil}, nil } // DeleteRule from the firewall by rule definition @@ -139,7 +143,10 @@ func (m *Manager) filterRuleSpecs( if direction == fw.DirectionSrc { specs = append(specs, "-s", ip.String()) } - specs = append(specs, "-p", "tcp", "--dport", port) + specs = append(specs, "-p", "tcp") + if port != "" { + specs = append(specs, "--dport", port) + } specs = append(specs, "-j", m.actionToStr(action)) return append(specs, "-m", "comment", "--comment", comment) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 2a98ee81e..d7093ad3c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1041,6 +1041,11 @@ func (e *Engine) close() { e.dnsServer.Stop() } + if e.firewallManager != nil { + if err := e.firewallManager.Reset(); err != nil { + log.Warnf("failed resetting firewall: %v", err) + } + } } // applyFirewallRules to the local firewall manager processed by ACL policy. @@ -1050,16 +1055,19 @@ func (e *Engine) applyFirewallRules(rules []*mgmProto.FirewallRule) error { return nil } - newRules := make([]string, 0) + newRules := make(map[string]struct{}, 0) for _, r := range rules { rule := e.protoRuleToFirewallRule(r) if rule == nil { continue } - newRules = append(newRules, rule.GetRuleID()) + newRules[rule.GetRuleID()] = struct{}{} } - for _, ruleID := range newRules { + for ruleID := range e.firewallRules { + if _, ok := newRules[ruleID]; ok { + continue + } if rule, ok := e.firewallRules[ruleID]; ok { if err := e.firewallManager.DeleteRule(rule); err != nil { log.Debug("failed to delete firewall rule: %v", err) @@ -1079,7 +1087,7 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule return nil } - var port firewall.Port + var port *firewall.Port if r.Port != "" { split := strings.Split(r.Port, "/") value, err := strconv.Atoi(split[0]) @@ -1087,6 +1095,7 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule log.Debug("invalid port, skipping firewall rule") return nil } + port = &firewall.Port{} port.Values = []int{value} // get protocol from the port suffix if it exists if len(split) > 1 { @@ -1124,9 +1133,9 @@ func (e *Engine) protoRuleToFirewallRule(r *mgmProto.FirewallRule) firewall.Rule return nil } - rule, err := e.firewallManager.AddFiltering(ip, &port, direction, action, "") + rule, err := e.firewallManager.AddFiltering(ip, port, direction, action, "") if err != nil { - log.Debug("failed to add firewall rule: %v", err) + log.Debugf("failed to add firewall rule: %v", err) return nil } e.firewallRules[rule.GetRuleID()] = rule