From 006ba32086244b87430a3c1e58f7a53a9aae4fe2 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 8 Dec 2023 10:48:21 +0100 Subject: [PATCH] Fix/acl for forward (#1305) Fix ACL on routed traffic and code refactor --- client/firewall/create.go | 32 + client/firewall/create_linux.go | 107 ++ client/firewall/iface.go | 11 + client/firewall/iptables/acl_linux.go | 473 +++++++ client/firewall/iptables/manager_linux.go | 475 ++----- .../firewall/iptables/manager_linux_test.go | 82 +- client/firewall/iptables/router_linux.go | 340 +++++ client/firewall/iptables/router_linux_test.go | 229 ++++ client/firewall/iptables/rule.go | 3 +- client/firewall/iptables/rulestore_linux.go | 50 + client/firewall/{ => manager}/firewall.go | 31 +- client/firewall/{ => manager}/port.go | 2 +- client/firewall/manager/routerpair.go | 18 + client/firewall/nftables/acl_linux.go | 1121 +++++++++++++++++ client/firewall/nftables/ipsetstore_linux.go | 85 ++ client/firewall/nftables/manager_linux.go | 815 ++---------- .../firewall/nftables/manager_linux_test.go | 28 +- client/firewall/nftables/route_linux.go | 413 ++++++ client/firewall/nftables/router_linux_test.go | 280 ++++ client/firewall/nftables/rule_linux.go | 7 +- client/firewall/nftables/ruleset_linux.go | 115 -- .../firewall/nftables/ruleset_linux_test.go | 122 -- client/firewall/test/cases_linux.go | 47 + client/firewall/uspfilter/allow_netbird.go | 8 +- .../firewall/uspfilter/allow_netbird_linux.go | 21 - client/firewall/uspfilter/rule.go | 4 +- client/firewall/uspfilter/uspfilter.go | 95 +- client/firewall/uspfilter/uspfilter_test.go | 34 +- client/internal/acl/manager.go | 142 +-- client/internal/acl/manager_create.go | 28 - client/internal/acl/manager_create_linux.go | 77 -- client/internal/acl/manager_test.go | 21 +- client/internal/checkfw/check.go | 3 - client/internal/checkfw/check_linux.go | 56 - client/internal/dns/host_linux.go | 20 +- client/internal/engine.go | 29 +- .../routemanager/common_linux_test.go | 75 -- client/internal/routemanager/firewall.go | 12 - .../internal/routemanager/firewall_linux.go | 55 - .../routemanager/firewall_nonlinux.go | 15 - .../internal/routemanager/iptables_linux.go | 487 ------- .../routemanager/iptables_linux_test.go | 294 ----- client/internal/routemanager/manager.go | 18 +- client/internal/routemanager/mock.go | 5 + .../internal/routemanager/nftables_linux.go | 571 --------- .../routemanager/nftables_linux_test.go | 324 ----- client/internal/routemanager/router_pair.go | 24 - .../internal/routemanager/server_android.go | 3 +- .../routemanager/server_nonandroid.go | 38 +- .../routemanager/systemops_nonandroid_test.go | 2 + 50 files changed, 3720 insertions(+), 3627 deletions(-) create mode 100644 client/firewall/create.go create mode 100644 client/firewall/create_linux.go create mode 100644 client/firewall/iface.go create mode 100644 client/firewall/iptables/acl_linux.go create mode 100644 client/firewall/iptables/router_linux.go create mode 100644 client/firewall/iptables/router_linux_test.go create mode 100644 client/firewall/iptables/rulestore_linux.go rename client/firewall/{ => manager}/firewall.go (70%) rename client/firewall/{ => manager}/port.go (98%) create mode 100644 client/firewall/manager/routerpair.go create mode 100644 client/firewall/nftables/acl_linux.go create mode 100644 client/firewall/nftables/ipsetstore_linux.go create mode 100644 client/firewall/nftables/route_linux.go create mode 100644 client/firewall/nftables/router_linux_test.go delete mode 100644 client/firewall/nftables/ruleset_linux.go delete mode 100644 client/firewall/nftables/ruleset_linux_test.go create mode 100644 client/firewall/test/cases_linux.go delete mode 100644 client/firewall/uspfilter/allow_netbird_linux.go delete mode 100644 client/internal/acl/manager_create.go delete mode 100644 client/internal/acl/manager_create_linux.go delete mode 100644 client/internal/checkfw/check.go delete mode 100644 client/internal/checkfw/check_linux.go delete mode 100644 client/internal/routemanager/common_linux_test.go delete mode 100644 client/internal/routemanager/firewall.go delete mode 100644 client/internal/routemanager/firewall_linux.go delete mode 100644 client/internal/routemanager/firewall_nonlinux.go delete mode 100644 client/internal/routemanager/iptables_linux.go delete mode 100644 client/internal/routemanager/iptables_linux_test.go delete mode 100644 client/internal/routemanager/nftables_linux.go delete mode 100644 client/internal/routemanager/nftables_linux_test.go delete mode 100644 client/internal/routemanager/router_pair.go diff --git a/client/firewall/create.go b/client/firewall/create.go new file mode 100644 index 000000000..86ce94cea --- /dev/null +++ b/client/firewall/create.go @@ -0,0 +1,32 @@ +//go:build !linux || android + +package firewall + +import ( + "context" + "fmt" + "runtime" + + log "github.com/sirupsen/logrus" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter" +) + +// NewFirewall creates a firewall manager instance +func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) { + if !iface.IsUserspaceBind() { + return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) + } + + // use userspace packet filtering firewall + fm, err := uspfilter.Create(iface) + if err != nil { + return nil, err + } + err = fm.AllowNetbird() + if err != nil { + log.Warnf("failed to allow netbird interface traffic: %v", err) + } + return fm, nil +} diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go new file mode 100644 index 000000000..a872e11c4 --- /dev/null +++ b/client/firewall/create_linux.go @@ -0,0 +1,107 @@ +//go:build !android + +package firewall + +import ( + "context" + "fmt" + "os" + + "github.com/coreos/go-iptables/iptables" + "github.com/google/nftables" + log "github.com/sirupsen/logrus" + + nbiptables "github.com/netbirdio/netbird/client/firewall/iptables" + firewall "github.com/netbirdio/netbird/client/firewall/manager" + nbnftables "github.com/netbirdio/netbird/client/firewall/nftables" + "github.com/netbirdio/netbird/client/firewall/uspfilter" +) + +const ( + // UNKNOWN is the default value for the firewall type for unknown firewall type + UNKNOWN FWType = iota + // IPTABLES is the value for the iptables firewall type + IPTABLES + // NFTABLES is the value for the nftables firewall type + NFTABLES +) + +// SKIP_NFTABLES_ENV is the environment variable to skip nftables check +const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" + +// FWType is the type for the firewall type +type FWType int + +func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) { + // on the linux system we try to user nftables or iptables + // in any case, because we need to allow netbird interface traffic + // so we use AllowNetbird traffic from these firewall managers + // for the userspace packet filtering firewall + var fm firewall.Manager + var errFw error + + switch check() { + case IPTABLES: + log.Debug("creating an iptables firewall manager") + fm, errFw = nbiptables.Create(context, iface) + if errFw != nil { + log.Errorf("failed to create iptables manager: %s", errFw) + } + case NFTABLES: + log.Debug("creating an nftables firewall manager") + fm, errFw = nbnftables.Create(context, iface) + if errFw != nil { + log.Errorf("failed to create nftables manager: %s", errFw) + } + default: + errFw = fmt.Errorf("no firewall manager found") + log.Debug("no firewall manager found, try to use userspace packet filtering firewall") + } + + if iface.IsUserspaceBind() { + var errUsp error + if errFw == nil { + fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) + } else { + fm, errUsp = uspfilter.Create(iface) + } + if errUsp != nil { + log.Debugf("failed to create userspace filtering firewall: %s", errUsp) + return nil, errUsp + } + + if err := fm.AllowNetbird(); err != nil { + log.Errorf("failed to allow netbird interface traffic: %v", err) + } + return fm, nil + } + + if errFw != nil { + return nil, errFw + } + + return fm, nil +} + +// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. +func check() FWType { + nf := nftables.Conn{} + if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" { + return NFTABLES + } + + ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + if err != nil { + return UNKNOWN + } + if isIptablesClientAvailable(ip) { + return IPTABLES + } + + return UNKNOWN +} + +func isIptablesClientAvailable(client *iptables.IPTables) bool { + _, err := client.ListChains("filter") + return err == nil +} diff --git a/client/firewall/iface.go b/client/firewall/iface.go new file mode 100644 index 000000000..882daef75 --- /dev/null +++ b/client/firewall/iface.go @@ -0,0 +1,11 @@ +package firewall + +import "github.com/netbirdio/netbird/iface" + +// IFaceMapper defines subset methods of interface required for manager +type IFaceMapper interface { + Name() string + Address() iface.WGAddress + IsUserspaceBind() bool + SetFilter(iface.PacketFilter) error +} diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go new file mode 100644 index 000000000..b77cc8f43 --- /dev/null +++ b/client/firewall/iptables/acl_linux.go @@ -0,0 +1,473 @@ +package iptables + +import ( + "fmt" + "net" + "strconv" + + "github.com/coreos/go-iptables/iptables" + "github.com/google/uuid" + "github.com/nadoo/ipset" + log "github.com/sirupsen/logrus" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) + +const ( + tableName = "filter" + + // rules chains contains the effective ACL rules + chainNameInputRules = "NETBIRD-ACL-INPUT" + chainNameOutputRules = "NETBIRD-ACL-OUTPUT" + + postRoutingMark = "0x000007e4" +) + +type aclManager struct { + iptablesClient *iptables.IPTables + wgIface iFaceMapper + routeingFwChainName string + + entries map[string][][]string + ipsetStore *ipsetStore +} + +func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) { + m := &aclManager{ + iptablesClient: iptablesClient, + wgIface: wgIface, + routeingFwChainName: routeingFwChainName, + + entries: make(map[string][][]string), + ipsetStore: newIpsetStore(), + } + + err := ipset.Init() + if err != nil { + return nil, fmt.Errorf("failed to init ipset: %w", err) + } + + m.seedInitialEntries() + + err = m.cleanChains() + if err != nil { + return nil, err + } + + err = m.createDefaultChains() + if err != nil { + return nil, err + } + return m, nil +} + +func (m *aclManager) AddFiltering( + ip net.IP, + protocol firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + direction firewall.RuleDirection, + action firewall.Action, + ipsetName string, +) ([]firewall.Rule, error) { + var dPortVal, sPortVal string + if dPort != nil && dPort.Values != nil { + // TODO: we support only one port per rule in current implementation of ACLs + dPortVal = strconv.Itoa(dPort.Values[0]) + } + if sPort != nil && sPort.Values != nil { + sPortVal = strconv.Itoa(sPort.Values[0]) + } + + var chain string + if direction == firewall.RuleDirectionOUT { + chain = chainNameOutputRules + } else { + chain = chainNameInputRules + } + + ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal) + specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName) + if ipsetName != "" { + if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists { + if err := ipset.Add(ipsetName, ip.String()); err != nil { + return nil, fmt.Errorf("failed to add IP to ipset: %w", err) + } + // if ruleset already exists it means we already have the firewall rule + // so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager. + ipList.addIP(ip.String()) + return []firewall.Rule{&Rule{ + ruleID: uuid.New().String(), + ipsetName: ipsetName, + ip: ip.String(), + chain: chain, + specs: specs, + }}, nil + } + + if err := ipset.Flush(ipsetName); err != nil { + log.Errorf("flush ipset %s before use it: %s", ipsetName, err) + } + if err := ipset.Create(ipsetName); err != nil { + return nil, fmt.Errorf("failed to create ipset: %w", err) + } + if err := ipset.Add(ipsetName, ip.String()); err != nil { + return nil, fmt.Errorf("failed to add IP to ipset: %w", err) + } + + ipList := newIpList(ip.String()) + m.ipsetStore.addIpList(ipsetName, ipList) + } + + ok, err := m.iptablesClient.Exists("filter", chain, specs...) + if err != nil { + return nil, fmt.Errorf("failed to check rule: %w", err) + } + if ok { + return nil, fmt.Errorf("rule already exists") + } + + if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil { + return nil, err + } + + rule := &Rule{ + ruleID: uuid.New().String(), + specs: specs, + ipsetName: ipsetName, + ip: ip.String(), + chain: chain, + } + + if !shouldAddToPrerouting(protocol, dPort, direction) { + return []firewall.Rule{rule}, nil + } + + rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip) + if err != nil { + return []firewall.Rule{rule}, err + } + return []firewall.Rule{rule, rulePrerouting}, nil +} + +// DeleteRule from the firewall by rule definition +func (m *aclManager) DeleteRule(rule firewall.Rule) error { + r, ok := rule.(*Rule) + if !ok { + return fmt.Errorf("invalid rule type") + } + + if r.chain == "PREROUTING" { + goto DELETERULE + } + + if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok { + // delete IP from ruleset IPs list and ipset + if _, ok := ipsetList.ips[r.ip]; ok { + if err := ipset.Del(r.ipsetName, r.ip); err != nil { + return fmt.Errorf("failed to delete ip from ipset: %w", err) + } + delete(ipsetList.ips, r.ip) + } + + // if after delete, set still contains other IPs, + // no need to delete firewall rule and we should exit here + if len(ipsetList.ips) != 0 { + return nil + } + + // we delete last IP from the set, that means we need to delete + // set itself and associated firewall rule too + m.ipsetStore.deleteIpset(r.ipsetName) + + if err := ipset.Destroy(r.ipsetName); err != nil { + log.Errorf("delete empty ipset: %v", err) + } + } + +DELETERULE: + var table string + if r.chain == "PREROUTING" { + table = "mangle" + } else { + table = "filter" + } + err := m.iptablesClient.Delete(table, r.chain, r.specs...) + if err != nil { + log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err) + } + return err +} + +func (m *aclManager) Reset() error { + return m.cleanChains() +} + +func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) { + var src []string + if ipsetName != "" { + src = []string{"-m", "set", "--set", ipsetName, "src"} + } else { + src = []string{"-s", ip.String()} + } + specs := []string{ + "-d", m.wgIface.Address().IP.String(), + "-p", protocol, + "--dport", port, + "-j", "MARK", "--set-mark", postRoutingMark, + } + + specs = append(src, specs...) + + ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...) + if err != nil { + return nil, fmt.Errorf("failed to check rule: %w", err) + } + if ok { + return nil, fmt.Errorf("rule already exists") + } + + if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil { + return nil, err + } + + rule := &Rule{ + ruleID: uuid.New().String(), + specs: specs, + ipsetName: ipsetName, + ip: ip.String(), + chain: "PREROUTING", + } + return rule, nil +} + +// todo write less destructive cleanup mechanism +func (m *aclManager) cleanChains() error { + ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules) + if err != nil { + log.Debugf("failed to list chains: %s", err) + return err + } + if ok { + rules := m.entries["OUTPUT"] + for _, rule := range rules { + err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...) + if err != nil { + log.Errorf("failed to delete rule: %v, %s", rule, err) + } + } + + err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules) + if err != nil { + log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err) + return err + } + } + + ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules) + if err != nil { + log.Debugf("failed to list chains: %s", err) + return err + } + if ok { + for _, rule := range m.entries["INPUT"] { + err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...) + if err != nil { + log.Errorf("failed to delete rule: %v, %s", rule, err) + } + } + + for _, rule := range m.entries["FORWARD"] { + err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...) + if err != nil { + log.Errorf("failed to delete rule: %v, %s", rule, err) + } + } + + err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules) + if err != nil { + log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err) + return err + } + } + + ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING") + if err != nil { + log.Debugf("failed to list chains: %s", err) + return err + } + if ok { + for _, rule := range m.entries["PREROUTING"] { + err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...) + if err != nil { + log.Errorf("failed to delete rule: %v, %s", rule, err) + } + } + err = m.iptablesClient.ClearChain("mangle", "PREROUTING") + if err != nil { + log.Debugf("failed to clear %s chain: %s", "PREROUTING", err) + return err + } + } + + for _, ipsetName := range m.ipsetStore.ipsetNames() { + if err := ipset.Flush(ipsetName); err != nil { + log.Errorf("flush ipset %q during reset: %v", ipsetName, err) + } + if err := ipset.Destroy(ipsetName); err != nil { + log.Errorf("delete ipset %q during reset: %v", ipsetName, err) + } + m.ipsetStore.deleteIpset(ipsetName) + } + + return nil +} + +func (m *aclManager) createDefaultChains() error { + // chain netbird-acl-input-rules + if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil { + log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err) + return err + } + + // chain netbird-acl-output-rules + if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil { + log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err) + return err + } + + for chainName, rules := range m.entries { + for _, rule := range rules { + if chainName == "FORWARD" { + // position 2 because we add it after router's, jump rule + if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil { + log.Debugf("failed to create input chain jump rule: %s", err) + return err + } + } else { + if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil { + log.Debugf("failed to create input chain jump rule: %s", err) + return err + } + } + } + } + + return nil +} + +func (m *aclManager) seedInitialEntries() { + m.appendToEntries("INPUT", + []string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) + + m.appendToEntries("INPUT", + []string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) + + m.appendToEntries("INPUT", + []string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules}) + + m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) + + m.appendToEntries("OUTPUT", + []string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) + + m.appendToEntries("OUTPUT", + []string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) + + m.appendToEntries("OUTPUT", + []string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules}) + + m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"}) + + m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) + m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) + m.appendToEntries("FORWARD", + []string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"}) + m.appendToEntries("FORWARD", + []string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"}) + m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName}) + m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName}) + + m.appendToEntries("PREROUTING", + []string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark}) +} + +func (m *aclManager) appendToEntries(chainName string, spec []string) { + m.entries[chainName] = append(m.entries[chainName], spec) +} + +// filterRuleSpecs returns the specs of a filtering rule +func filterRuleSpecs( + ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string, +) (specs []string) { + matchByIP := true + // don't use IP matching if IP is ip 0.0.0.0 + if ip.String() == "0.0.0.0" { + matchByIP = false + } + switch direction { + case firewall.RuleDirectionIN: + if matchByIP { + if ipsetName != "" { + specs = append(specs, "-m", "set", "--set", ipsetName, "src") + } else { + specs = append(specs, "-s", ip.String()) + } + } + case firewall.RuleDirectionOUT: + if matchByIP { + if ipsetName != "" { + specs = append(specs, "-m", "set", "--set", ipsetName, "dst") + } else { + specs = append(specs, "-d", ip.String()) + } + } + } + if protocol != "all" { + specs = append(specs, "-p", protocol) + } + if sPort != "" { + specs = append(specs, "--sport", sPort) + } + if dPort != "" { + specs = append(specs, "--dport", dPort) + } + return append(specs, "-j", actionToStr(action)) +} + +func actionToStr(action firewall.Action) string { + if action == firewall.ActionAccept { + return "ACCEPT" + } + return "DROP" +} + +func transformIPsetName(ipsetName string, sPort, dPort string) string { + switch { + case ipsetName == "": + return "" + case sPort != "" && dPort != "": + return ipsetName + "-sport-dport" + case sPort != "": + return ipsetName + "-sport" + case dPort != "": + return ipsetName + "-dport" + default: + return ipsetName + } +} + +func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool { + if proto == "all" { + return false + } + + if direction != firewall.RuleDirectionIN { + return false + } + + if dPort == nil { + return false + } + return true +} diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index b9243f4ca..2d231ec45 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -1,43 +1,27 @@ package iptables import ( + "context" "fmt" "net" - "strconv" "sync" "github.com/coreos/go-iptables/iptables" - "github.com/google/uuid" - "github.com/nadoo/ipset" log "github.com/sirupsen/logrus" - fw "github.com/netbirdio/netbird/client/firewall" + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/iface" ) -const ( - // ChainInputFilterName is the name of the chain that is used for filtering incoming packets - ChainInputFilterName = "NETBIRD-ACL-INPUT" - - // ChainOutputFilterName is the name of the chain that is used for filtering outgoing packets - ChainOutputFilterName = "NETBIRD-ACL-OUTPUT" -) - -// dropAllDefaultRule in the Netbird chain -var dropAllDefaultRule = []string{"-j", "DROP"} - // Manager of iptables firewall type Manager struct { mutex sync.Mutex + wgIface iFaceMapper + ipv4Client *iptables.IPTables - ipv6Client *iptables.IPTables - - inputDefaultRuleSpecs []string - outputDefaultRuleSpecs []string - wgIface iFaceMapper - - rulesets map[string]ruleset + aclMgr *aclManager + router *routerManager } // iFaceMapper defines subset methods of interface required for manager @@ -47,47 +31,29 @@ type iFaceMapper interface { IsUserspaceBind() bool } -type ruleset struct { - rule *Rule - ips map[string]string -} - // Create iptables firewall manager -func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) { - m := &Manager{ - wgIface: wgIface, - inputDefaultRuleSpecs: []string{ - "-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()}, - outputDefaultRuleSpecs: []string{ - "-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()}, - rulesets: make(map[string]ruleset), - } - - err := ipset.Init() - if err != nil { - return nil, fmt.Errorf("init ipset: %w", err) - } - - // init clients for booth ipv4 and ipv6 - m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4) +func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { + iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { return nil, fmt.Errorf("iptables is not installed in the system or not supported") } - if ipv6Supported { - m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6) - if err != nil { - log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err) - } + m := &Manager{ + wgIface: wgIface, + ipv4Client: iptablesClient, } - if m.ipv4Client == nil && m.ipv6Client == nil { - return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it") + m.router, err = newRouterManager(context, iptablesClient) + if err != nil { + log.Debugf("failed to initialize route related chains: %s", err) + return nil, err + } + m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName()) + if err != nil { + log.Debugf("failed to initialize ACL manager: %s", err) + return nil, err } - if err := m.Reset(); err != nil { - return nil, fmt.Errorf("failed to reset firewall: %v", err) - } return m, nil } @@ -96,159 +62,44 @@ func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) { // Comment will be ignored because some system this feature is not supported func (m *Manager) AddFiltering( ip net.IP, - protocol fw.Protocol, - sPort *fw.Port, - dPort *fw.Port, - direction fw.RuleDirection, - action fw.Action, + protocol firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + direction firewall.RuleDirection, + action firewall.Action, ipsetName string, comment string, -) (fw.Rule, error) { +) ([]firewall.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() - client, err := m.client(ip) - if err != nil { - return nil, err - } - - var dPortVal, sPortVal string - if dPort != nil && dPort.Values != nil { - // TODO: we support only one port per rule in current implementation of ACLs - dPortVal = strconv.Itoa(dPort.Values[0]) - } - if sPort != nil && sPort.Values != nil { - sPortVal = strconv.Itoa(sPort.Values[0]) - } - ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal) - - ruleID := uuid.New().String() - - if ipsetName != "" { - rs, rsExists := m.rulesets[ipsetName] - if !rsExists { - if err := ipset.Flush(ipsetName); err != nil { - log.Errorf("flush ipset %q before use it: %v", ipsetName, err) - } - if err := ipset.Create(ipsetName); err != nil { - return nil, fmt.Errorf("failed to create ipset: %w", err) - } - } - - if err := ipset.Add(ipsetName, ip.String()); err != nil { - return nil, fmt.Errorf("failed to add IP to ipset: %w", err) - } - - if rsExists { - // if ruleset already exists it means we already have the firewall rule - // so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager. - rs.ips[ip.String()] = ruleID - return &Rule{ - ruleID: ruleID, - ipsetName: ipsetName, - ip: ip.String(), - dst: direction == fw.RuleDirectionOUT, - v6: ip.To4() == nil, - }, nil - } - // this is new ipset so we need to create firewall rule for it - } - - specs := m.filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName) - - if direction == fw.RuleDirectionOUT { - ok, err := client.Exists("filter", ChainOutputFilterName, specs...) - if err != nil { - return nil, fmt.Errorf("check is output rule already exists: %w", err) - } - if ok { - return nil, fmt.Errorf("input rule already exists") - } - - if err := client.Insert("filter", ChainOutputFilterName, 1, specs...); err != nil { - return nil, err - } - } else { - ok, err := client.Exists("filter", ChainInputFilterName, specs...) - if err != nil { - return nil, fmt.Errorf("check is input rule already exists: %w", err) - } - if ok { - return nil, fmt.Errorf("input rule already exists") - } - - if err := client.Insert("filter", ChainInputFilterName, 1, specs...); err != nil { - return nil, err - } - } - - rule := &Rule{ - ruleID: ruleID, - specs: specs, - ipsetName: ipsetName, - ip: ip.String(), - dst: direction == fw.RuleDirectionOUT, - v6: ip.To4() == nil, - } - if ipsetName != "" { - // ipset name is defined and it means that this rule was created - // for it, need to associate it with ruleset - m.rulesets[ipsetName] = ruleset{ - rule: rule, - ips: map[string]string{rule.ip: ruleID}, - } - } - - return rule, nil + return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) } // DeleteRule from the firewall by rule definition -func (m *Manager) DeleteRule(rule fw.Rule) error { +func (m *Manager) DeleteRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() - r, ok := rule.(*Rule) - if !ok { - return fmt.Errorf("invalid rule type") - } + return m.aclMgr.DeleteRule(rule) +} - client := m.ipv4Client - if r.v6 { - if m.ipv6Client == nil { - return fmt.Errorf("ipv6 is not supported") - } - client = m.ipv6Client - } +func (m *Manager) IsServerRouteSupported() bool { + return true +} - if rs, ok := m.rulesets[r.ipsetName]; ok { - // delete IP from ruleset IPs list and ipset - if _, ok := rs.ips[r.ip]; ok { - if err := ipset.Del(r.ipsetName, r.ip); err != nil { - return fmt.Errorf("failed to delete ip from ipset: %w", err) - } - delete(rs.ips, r.ip) - } +func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { + m.mutex.Lock() + defer m.mutex.Unlock() - // if after delete, set still contains other IPs, - // no need to delete firewall rule and we should exit here - if len(rs.ips) != 0 { - return nil - } + return m.router.InsertRoutingRules(pair) +} - // we delete last IP from the set, that means we need to delete - // set itself and associated firewall rule too - delete(m.rulesets, r.ipsetName) +func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { + m.mutex.Lock() + defer m.mutex.Unlock() - if err := ipset.Destroy(r.ipsetName); err != nil { - log.Errorf("delete empty ipset: %v", err) - } - r = rs.rule - } - - if r.dst { - return client.Delete("filter", ChainOutputFilterName, r.specs...) - } - return client.Delete("filter", ChainInputFilterName, r.specs...) + return m.router.RemoveRoutingRules(pair) } // Reset firewall to the default state @@ -256,223 +107,49 @@ func (m *Manager) Reset() error { m.mutex.Lock() defer m.mutex.Unlock() - if err := m.reset(m.ipv4Client, "filter"); err != nil { - return fmt.Errorf("clean ipv4 firewall ACL input chain: %w", err) + errAcl := m.aclMgr.Reset() + if errAcl != nil { + log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl) } - if m.ipv6Client != nil { - if err := m.reset(m.ipv6Client, "filter"); err != nil { - return fmt.Errorf("clean ipv6 firewall ACL input chain: %w", err) - } + errMgr := m.router.Reset() + if errMgr != nil { + log.Errorf("failed to clean up router rules from firewall: %s", errMgr) + return errMgr } - - return nil + return errAcl } // AllowNetbird allows netbird interface traffic func (m *Manager) AllowNetbird() error { - if m.wgIface.IsUserspaceBind() { - _, err := m.AddFiltering( - net.ParseIP("0.0.0.0"), - "all", - nil, - nil, - fw.RuleDirectionIN, - fw.ActionAccept, - "", - "", - ) - if err != nil { - return fmt.Errorf("failed to allow netbird interface traffic: %w", err) - } - _, err = m.AddFiltering( - net.ParseIP("0.0.0.0"), - "all", - nil, - nil, - fw.RuleDirectionOUT, - fw.ActionAccept, - "", - "", - ) - return err + if !m.wgIface.IsUserspaceBind() { + return nil } - return nil + _, err := m.AddFiltering( + net.ParseIP("0.0.0.0"), + "all", + nil, + nil, + firewall.RuleDirectionIN, + firewall.ActionAccept, + "", + "", + ) + if err != nil { + return fmt.Errorf("failed to allow netbird interface traffic: %w", err) + } + _, err = m.AddFiltering( + net.ParseIP("0.0.0.0"), + "all", + nil, + nil, + firewall.RuleDirectionOUT, + firewall.ActionAccept, + "", + "", + ) + return err } // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } - -// reset firewall chain, clear it and drop it -func (m *Manager) reset(client *iptables.IPTables, table string) error { - ok, err := client.ChainExists(table, ChainInputFilterName) - if err != nil { - return fmt.Errorf("failed to check if input chain exists: %w", err) - } - if ok { - if ok, err := client.Exists("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil { - return err - } else if ok { - if err := client.Delete("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil { - log.WithError(err).Errorf("failed to delete default input rule: %v", err) - } - } - } - - ok, err = client.ChainExists(table, ChainOutputFilterName) - if err != nil { - return fmt.Errorf("failed to check if output chain exists: %w", err) - } - if ok { - if ok, err := client.Exists("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil { - return err - } else if ok { - if err := client.Delete("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil { - log.WithError(err).Errorf("failed to delete default output rule: %v", err) - } - } - } - - if err := client.ClearAndDeleteChain(table, ChainInputFilterName); err != nil { - log.Errorf("failed to clear and delete input chain: %v", err) - return nil - } - - if err := client.ClearAndDeleteChain(table, ChainOutputFilterName); err != nil { - log.Errorf("failed to clear and delete input chain: %v", err) - return nil - } - - for ipsetName := range m.rulesets { - if err := ipset.Flush(ipsetName); err != nil { - log.Errorf("flush ipset %q during reset: %v", ipsetName, err) - } - if err := ipset.Destroy(ipsetName); err != nil { - log.Errorf("delete ipset %q during reset: %v", ipsetName, err) - } - delete(m.rulesets, ipsetName) - } - - return nil -} - -// filterRuleSpecs returns the specs of a filtering rule -func (m *Manager) filterRuleSpecs( - ip net.IP, protocol string, sPort, dPort string, direction fw.RuleDirection, action fw.Action, ipsetName string, -) (specs []string) { - matchByIP := true - // don't use IP matching if IP is ip 0.0.0.0 - if s := ip.String(); s == "0.0.0.0" || s == "::" { - matchByIP = false - } - switch direction { - case fw.RuleDirectionIN: - if matchByIP { - if ipsetName != "" { - specs = append(specs, "-m", "set", "--set", ipsetName, "src") - } else { - specs = append(specs, "-s", ip.String()) - } - } - case fw.RuleDirectionOUT: - if matchByIP { - if ipsetName != "" { - specs = append(specs, "-m", "set", "--set", ipsetName, "dst") - } else { - specs = append(specs, "-d", ip.String()) - } - } - } - if protocol != "all" { - specs = append(specs, "-p", protocol) - } - if sPort != "" { - specs = append(specs, "--sport", sPort) - } - if dPort != "" { - specs = append(specs, "--dport", dPort) - } - return append(specs, "-j", m.actionToStr(action)) -} - -// rawClient returns corresponding iptables client for the given ip -func (m *Manager) rawClient(ip net.IP) (*iptables.IPTables, error) { - if ip.To4() != nil { - return m.ipv4Client, nil - } - if m.ipv6Client == nil { - return nil, fmt.Errorf("ipv6 is not supported") - } - return m.ipv6Client, nil -} - -// client returns client with initialized chain and default rules -func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) { - client, err := m.rawClient(ip) - if err != nil { - return nil, err - } - - ok, err := client.ChainExists("filter", ChainInputFilterName) - if err != nil { - return nil, fmt.Errorf("failed to check if chain exists: %w", err) - } - - if !ok { - if err := client.NewChain("filter", ChainInputFilterName); err != nil { - return nil, fmt.Errorf("failed to create input chain: %w", err) - } - - if err := client.AppendUnique("filter", ChainInputFilterName, dropAllDefaultRule...); err != nil { - return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err) - } - - if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil { - return nil, fmt.Errorf("failed to create input chain jump rule: %w", err) - } - - } - - ok, err = client.ChainExists("filter", ChainOutputFilterName) - if err != nil { - return nil, fmt.Errorf("failed to check if chain exists: %w", err) - } - - if !ok { - if err := client.NewChain("filter", ChainOutputFilterName); err != nil { - return nil, fmt.Errorf("failed to create output chain: %w", err) - } - - if err := client.AppendUnique("filter", ChainOutputFilterName, dropAllDefaultRule...); err != nil { - return nil, fmt.Errorf("failed to create default drop all in netbird output chain: %w", err) - } - - if err := client.AppendUnique("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil { - return nil, fmt.Errorf("failed to create output chain jump rule: %w", err) - } - } - - return client, nil -} - -func (m *Manager) actionToStr(action fw.Action) string { - if action == fw.ActionAccept { - return "ACCEPT" - } - return "DROP" -} - -func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string { - switch { - case ipsetName == "": - return "" - case sPort != "" && dPort != "": - return ipsetName + "-sport-dport" - case sPort != "": - return ipsetName + "-sport" - case dPort != "": - return ipsetName + "-dport" - default: - return ipsetName - } -} diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index 90375d3e2..ceb116c62 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -1,6 +1,7 @@ package iptables import ( + "context" "fmt" "net" "testing" @@ -9,7 +10,7 @@ import ( "github.com/coreos/go-iptables/iptables" "github.com/stretchr/testify/require" - fw "github.com/netbirdio/netbird/client/firewall" + fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/iface" ) @@ -55,7 +56,7 @@ func TestIptablesManager(t *testing.T) { } // just check on the local interface - manager, err := Create(mock, true) + manager, err := Create(context.Background(), mock) require.NoError(t, err) time.Sleep(time.Second) @@ -67,17 +68,20 @@ func TestIptablesManager(t *testing.T) { time.Sleep(time.Second) }() - var rule1 fw.Rule + var rule1 []fw.Rule t.Run("add first rule", func(t *testing.T) { ip := net.ParseIP("10.20.0.2") port := &fw.Port{Values: []int{8080}} rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") require.NoError(t, err, "failed to add rule") - checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...) + for _, r := range rule1 { + checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...) + } + }) - var rule2 fw.Rule + var rule2 []fw.Rule t.Run("add second rule", func(t *testing.T) { ip := net.ParseIP("10.20.0.3") port := &fw.Port{ @@ -87,21 +91,28 @@ func TestIptablesManager(t *testing.T) { ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range") require.NoError(t, err, "failed to add rule") - checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...) + for _, r := range rule2 { + rr := r.(*Rule) + checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...) + } }) t.Run("delete first rule", func(t *testing.T) { - err := manager.DeleteRule(rule1) - require.NoError(t, err, "failed to delete rule") + for _, r := range rule1 { + err := manager.DeleteRule(r) + require.NoError(t, err, "failed to delete rule") - checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...) + checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...) + } }) t.Run("delete second rule", func(t *testing.T) { - err := manager.DeleteRule(rule2) - require.NoError(t, err, "failed to delete rule") + for _, r := range rule2 { + err := manager.DeleteRule(r) + require.NoError(t, err, "failed to delete rule") + } - require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty") + require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty") }) t.Run("reset check", func(t *testing.T) { @@ -114,11 +125,11 @@ func TestIptablesManager(t *testing.T) { err = manager.Reset() require.NoError(t, err, "failed to reset") - ok, err := ipv4Client.ChainExists("filter", ChainInputFilterName) + ok, err := ipv4Client.ChainExists("filter", chainNameInputRules) require.NoError(t, err, "failed check chain exists") if ok { - require.NoErrorf(t, err, "chain '%v' still exists after Reset", ChainInputFilterName) + require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules) } }) } @@ -143,7 +154,7 @@ func TestIptablesManagerIPSet(t *testing.T) { } // just check on the local interface - manager, err := Create(mock, true) + manager, err := Create(context.Background(), mock) require.NoError(t, err) time.Sleep(time.Second) @@ -155,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) { time.Sleep(time.Second) }() - var rule1 fw.Rule + var rule1 []fw.Rule t.Run("add first rule with set", func(t *testing.T) { ip := net.ParseIP("10.20.0.2") port := &fw.Port{Values: []int{8080}} @@ -165,12 +176,14 @@ func TestIptablesManagerIPSet(t *testing.T) { ) require.NoError(t, err, "failed to add rule") - checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...) - require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set") - require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set") + for _, r := range rule1 { + checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...) + require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set") + require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set") + } }) - var rule2 fw.Rule + var rule2 []fw.Rule t.Run("add second rule", func(t *testing.T) { ip := net.ParseIP("10.20.0.3") port := &fw.Port{ @@ -180,23 +193,29 @@ func TestIptablesManagerIPSet(t *testing.T) { ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "default", "accept HTTPS traffic from ports range", ) - require.NoError(t, err, "failed to add rule") - require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set") - require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set") + for _, r := range rule2 { + require.NoError(t, err, "failed to add rule") + require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") + require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set") + } }) t.Run("delete first rule", func(t *testing.T) { - err := manager.DeleteRule(rule1) - require.NoError(t, err, "failed to delete rule") + for _, r := range rule1 { + err := manager.DeleteRule(r) + require.NoError(t, err, "failed to delete rule") - require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index") + require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index") + } }) t.Run("delete second rule", func(t *testing.T) { - err := manager.DeleteRule(rule2) - require.NoError(t, err, "failed to delete rule") + for _, r := range rule2 { + err := manager.DeleteRule(r) + require.NoError(t, err, "failed to delete rule") - require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty") + require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty") + } }) t.Run("reset check", func(t *testing.T) { @@ -206,7 +225,7 @@ func TestIptablesManagerIPSet(t *testing.T) { } func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) { - t.Helper() + t.Helper() exists, err := ipv4Client.Exists("filter", chainName, rulespec...) require.NoError(t, err, "failed to check rule") require.Falsef(t, !exists && mustExists, "rule '%v' does not exist", rulespec) @@ -232,7 +251,7 @@ func TestIptablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(mock, true) + manager, err := Create(context.Background(), mock) require.NoError(t, err) time.Sleep(time.Second) @@ -243,7 +262,6 @@ func TestIptablesCreatePerformance(t *testing.T) { time.Sleep(time.Second) }() - _, err = manager.client(net.ParseIP("10.20.0.100")) require.NoError(t, err) ip := net.ParseIP("10.20.0.100") diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go new file mode 100644 index 000000000..fc395c9cf --- /dev/null +++ b/client/firewall/iptables/router_linux.go @@ -0,0 +1,340 @@ +//go:build !android + +package iptables + +import ( + "context" + "fmt" + "strings" + + "github.com/coreos/go-iptables/iptables" + log "github.com/sirupsen/logrus" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) + +const ( + Ipv4Forwarding = "netbird-rt-forwarding" + ipv4Nat = "netbird-rt-nat" +) + +// constants needed to manage and create iptable rules +const ( + tableFilter = "filter" + tableNat = "nat" + chainFORWARD = "FORWARD" + chainPOSTROUTING = "POSTROUTING" + chainRTNAT = "NETBIRD-RT-NAT" + chainRTFWD = "NETBIRD-RT-FWD" + routingFinalForwardJump = "ACCEPT" + routingFinalNatJump = "MASQUERADE" +) + +type routerManager struct { + ctx context.Context + stop context.CancelFunc + iptablesClient *iptables.IPTables + rules map[string][]string +} + +func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) { + ctx, cancel := context.WithCancel(parentCtx) + m := &routerManager{ + ctx: ctx, + stop: cancel, + iptablesClient: iptablesClient, + rules: make(map[string][]string), + } + + err := m.cleanUpDefaultForwardRules() + if err != nil { + log.Errorf("failed to cleanup routing rules: %s", err) + return nil, err + } + err = m.createContainers() + if err != nil { + log.Errorf("failed to create containers for route: %s", err) + } + return m, err +} + +// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain +func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error { + err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair) + if err != nil { + return err + } + + err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair)) + if err != nil { + return err + } + + if !pair.Masquerade { + return nil + } + + err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair) + if err != nil { + return err + } + + err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair)) + if err != nil { + return err + } + + return nil +} + +// insertRoutingRule inserts an iptable rule +func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error { + var err error + + ruleKey := firewall.GenKey(keyFormat, pair.ID) + rule := genRuleSpec(jump, ruleKey, pair.Source, pair.Destination) + existingRule, found := i.rules[ruleKey] + if found { + err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...) + if err != nil { + return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) + } + delete(i.rules, ruleKey) + } + err = i.iptablesClient.Insert(table, chain, 1, rule...) + if err != nil { + return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) + } + + i.rules[ruleKey] = rule + + return nil +} + +// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains +func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error { + err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair) + if err != nil { + return err + } + + err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair)) + if err != nil { + return err + } + + if !pair.Masquerade { + return nil + } + + err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair) + if err != nil { + return err + } + + err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair)) + if err != nil { + return err + } + + return nil +} + +func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error { + var err error + + ruleKey := firewall.GenKey(keyFormat, pair.ID) + existingRule, found := i.rules[ruleKey] + if found { + err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...) + if err != nil { + return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) + } + } + delete(i.rules, ruleKey) + + return nil +} + +func (i *routerManager) RouteingFwChainName() string { + return chainRTFWD +} + +func (i *routerManager) Reset() error { + err := i.cleanUpDefaultForwardRules() + if err != nil { + return err + } + i.rules = make(map[string][]string) + return nil +} + +func (i *routerManager) cleanUpDefaultForwardRules() error { + err := i.cleanJumpRules() + if err != nil { + return err + } + + log.Debug("flushing routing related tables") + ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD) + if err != nil { + log.Errorf("failed check chain %s,error: %v", chainRTFWD, err) + return err + } else if ok { + err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD) + if err != nil { + log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err) + return err + } + } + + ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT) + if err != nil { + log.Errorf("failed check chain %s,error: %v", chainRTNAT, err) + return err + } else if ok { + err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT) + if err != nil { + log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err) + return err + } + } + return nil +} + +func (i *routerManager) createContainers() error { + if i.rules[Ipv4Forwarding] != nil { + return nil + } + + errMSGFormat := "failed creating chain %s,error: %v" + err := i.createChain(tableFilter, chainRTFWD) + if err != nil { + return fmt.Errorf(errMSGFormat, chainRTFWD, err) + } + + err = i.createChain(tableNat, chainRTNAT) + if err != nil { + return fmt.Errorf(errMSGFormat, chainRTNAT, err) + } + + err = i.addJumpRules() + if err != nil { + return fmt.Errorf("error while creating jump rules: %v", err) + } + + return nil +} + +// addJumpRules create jump rules to send packets to NetBird chains +func (i *routerManager) addJumpRules() error { + rule := []string{"-j", chainRTFWD} + err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...) + if err != nil { + return err + } + i.rules[Ipv4Forwarding] = rule + + rule = []string{"-j", chainRTNAT} + err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...) + if err != nil { + return err + } + i.rules[ipv4Nat] = rule + + return nil +} + +// cleanJumpRules cleans jump rules that was sending packets to NetBird chains +func (i *routerManager) cleanJumpRules() error { + var err error + errMSGFormat := "failed cleaning rule from chain %s,err: %v" + rule, found := i.rules[Ipv4Forwarding] + if found { + err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...) + if err != nil { + return fmt.Errorf(errMSGFormat, chainFORWARD, err) + } + } + rule, found = i.rules[ipv4Nat] + if found { + err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...) + if err != nil { + return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err) + } + } + + rules, err := i.iptablesClient.List("nat", "POSTROUTING") + if err != nil { + return fmt.Errorf("failed to list rules: %s", err) + } + + for _, ruleString := range rules { + if !strings.Contains(ruleString, "NETBIRD") { + continue + } + rule := strings.Fields(ruleString) + err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...) + if err != nil { + return fmt.Errorf("failed to delete postrouting jump rule: %s", err) + } + } + + rules, err = i.iptablesClient.List(tableFilter, "FORWARD") + if err != nil { + return fmt.Errorf("failed to list rules in FORWARD chain: %s", err) + } + + for _, ruleString := range rules { + if !strings.Contains(ruleString, "NETBIRD") { + continue + } + rule := strings.Fields(ruleString) + err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...) + if err != nil { + return fmt.Errorf("failed to delete FORWARD jump rule: %s", err) + } + } + return nil +} + +func (i *routerManager) createChain(table, newChain string) error { + chains, err := i.iptablesClient.ListChains(table) + if err != nil { + return fmt.Errorf("couldn't get %s table chains, error: %v", table, err) + } + + shouldCreateChain := true + for _, chain := range chains { + if chain == newChain { + shouldCreateChain = false + } + } + + if shouldCreateChain { + err = i.iptablesClient.NewChain(table, newChain) + if err != nil { + return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err) + } + + err = i.iptablesClient.Append(table, newChain, "-j", "RETURN") + if err != nil { + return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err) + } + + } + return nil +} + +// genRuleSpec generates rule specification with comment identifier +func genRuleSpec(jump, id, source, destination string) []string { + return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id} +} + +func getIptablesRuleType(table string) string { + ruleType := "forwarding" + if table == tableNat { + ruleType = "nat" + } + return ruleType +} diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go new file mode 100644 index 000000000..b4b81a389 --- /dev/null +++ b/client/firewall/iptables/router_linux_test.go @@ -0,0 +1,229 @@ +//go:build !android + +package iptables + +import ( + "context" + "os/exec" + "testing" + + "github.com/coreos/go-iptables/iptables" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/test" +) + +func isIptablesSupported() bool { + _, err4 := exec.LookPath("iptables") + return err4 == nil +} + +func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { + if !isIptablesSupported() { + t.SkipNow() + } + + iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + require.NoError(t, err, "failed to init iptables client") + + manager, err := newRouterManager(context.TODO(), iptablesClient) + require.NoError(t, err, "should return a valid iptables manager") + + defer func() { + _ = manager.Reset() + }() + + require.Len(t, manager.rules, 2, "should have created rules map") + + exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD) + require.True(t, exists, "forwarding rule should exist") + + exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) + require.True(t, exists, "postrouting rule should exist") + + pair := firewall.RouterPair{ + ID: "abc", + Source: "100.100.100.1/32", + Destination: "100.100.100.0/24", + Masquerade: true, + } + forward4RuleKey := firewall.GenKey(firewall.ForwardingFormat, pair.ID) + forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.Source, pair.Destination) + + err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...) + require.NoError(t, err, "inserting rule should not return error") + + nat4RuleKey := firewall.GenKey(firewall.NatFormat, pair.ID) + nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.Source, pair.Destination) + + err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...) + require.NoError(t, err, "inserting rule should not return error") + + err = manager.Reset() + require.NoError(t, err, "shouldn't return error") +} + +func TestIptablesManager_InsertRoutingRules(t *testing.T) { + + if !isIptablesSupported() { + t.SkipNow() + } + + for _, testCase := range test.InsertRuleTestCases { + t.Run(testCase.Name, func(t *testing.T) { + iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + require.NoError(t, err, "failed to init iptables client") + + manager, err := newRouterManager(context.TODO(), iptablesClient) + require.NoError(t, err, "shouldn't return error") + + defer func() { + err := manager.Reset() + if err != nil { + log.Errorf("failed to reset iptables manager: %s", err) + } + }() + + err = manager.InsertRoutingRules(testCase.InputPair) + require.NoError(t, err, "forwarding pair should be inserted") + + forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) + forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination) + + exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) + require.True(t, exists, "forwarding rule should exist") + + foundRule, found := manager.rules[forwardRuleKey] + require.True(t, found, "forwarding rule should exist in the manager map") + require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match") + + inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) + inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) + + exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) + require.True(t, exists, "income forwarding rule should exist") + + foundRule, found = manager.rules[inForwardRuleKey] + require.True(t, found, "income forwarding rule should exist in the manager map") + require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match") + + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) + natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination) + + exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) + if testCase.InputPair.Masquerade { + require.True(t, exists, "nat rule should be created") + foundNatRule, foundNat := manager.rules[natRuleKey] + require.True(t, foundNat, "nat rule should exist in the map") + require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match") + } else { + require.False(t, exists, "nat rule should not be created") + _, foundNat := manager.rules[natRuleKey] + require.False(t, foundNat, "nat rule should not exist in the map") + } + + inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) + inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) + + exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) + if testCase.InputPair.Masquerade { + require.True(t, exists, "income nat rule should be created") + foundNatRule, foundNat := manager.rules[inNatRuleKey] + require.True(t, foundNat, "income nat rule should exist in the map") + require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match") + } else { + require.False(t, exists, "nat rule should not be created") + _, foundNat := manager.rules[inNatRuleKey] + require.False(t, foundNat, "income nat rule should not exist in the map") + } + }) + } +} + +func TestIptablesManager_RemoveRoutingRules(t *testing.T) { + + if !isIptablesSupported() { + t.SkipNow() + } + + for _, testCase := range test.RemoveRuleTestCases { + t.Run(testCase.Name, func(t *testing.T) { + iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) + + manager, err := newRouterManager(context.TODO(), iptablesClient) + require.NoError(t, err, "shouldn't return error") + defer func() { + _ = manager.Reset() + }() + + require.NoError(t, err, "shouldn't return error") + + forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) + forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination) + + err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...) + require.NoError(t, err, "inserting rule should not return error") + + inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) + inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) + + err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...) + require.NoError(t, err, "inserting rule should not return error") + + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) + natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination) + + err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...) + require.NoError(t, err, "inserting rule should not return error") + + inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) + inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) + + err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...) + require.NoError(t, err, "inserting rule should not return error") + + err = manager.Reset() + require.NoError(t, err, "shouldn't return error") + + err = manager.RemoveRoutingRules(testCase.InputPair) + require.NoError(t, err, "shouldn't return error") + + exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) + require.False(t, exists, "forwarding rule should not exist") + + _, found := manager.rules[forwardRuleKey] + require.False(t, found, "forwarding rule should exist in the manager map") + + exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) + require.False(t, exists, "income forwarding rule should not exist") + + _, found = manager.rules[inForwardRuleKey] + require.False(t, found, "income forwarding rule should exist in the manager map") + + exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) + require.False(t, exists, "nat rule should not exist") + + _, found = manager.rules[natRuleKey] + require.False(t, found, "nat rule should exist in the manager map") + + exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) + require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) + require.False(t, exists, "income nat rule should not exist") + + _, found = manager.rules[inNatRuleKey] + require.False(t, found, "income nat rule should exist in the manager map") + + }) + } +} diff --git a/client/firewall/iptables/rule.go b/client/firewall/iptables/rule.go index f65030d39..1047c5cf8 100644 --- a/client/firewall/iptables/rule.go +++ b/client/firewall/iptables/rule.go @@ -7,8 +7,7 @@ type Rule struct { specs []string ip string - dst bool - v6 bool + chain string } // GetRuleID returns the rule id diff --git a/client/firewall/iptables/rulestore_linux.go b/client/firewall/iptables/rulestore_linux.go new file mode 100644 index 000000000..a9470c9ac --- /dev/null +++ b/client/firewall/iptables/rulestore_linux.go @@ -0,0 +1,50 @@ +package iptables + +type ipList struct { + ips map[string]struct{} +} + +func newIpList(ip string) ipList { + ips := make(map[string]struct{}) + ips[ip] = struct{}{} + + return ipList{ + ips: ips, + } +} + +func (s *ipList) addIP(ip string) { + s.ips[ip] = struct{}{} +} + +type ipsetStore struct { + ipsets map[string]ipList // ipsetName -> ruleset +} + +func newIpsetStore() *ipsetStore { + return &ipsetStore{ + ipsets: make(map[string]ipList), + } +} + +func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) { + r, ok := s.ipsets[ipsetName] + return r, ok +} + +func (s *ipsetStore) addIpList(ipsetName string, list ipList) { + s.ipsets[ipsetName] = list +} + +func (s *ipsetStore) deleteIpset(ipsetName string) { + s.ipsets[ipsetName] = ipList{} + delete(s.ipsets, ipsetName) +} + +func (s *ipsetStore) ipsetNames() []string { + names := make([]string, 0, len(s.ipsets)) + for name := range s.ipsets { + names = append(names, name) + } + return names +} diff --git a/client/firewall/firewall.go b/client/firewall/manager/firewall.go similarity index 70% rename from client/firewall/firewall.go rename to client/firewall/manager/firewall.go index 59e672a45..6e4edb63e 100644 --- a/client/firewall/firewall.go +++ b/client/firewall/manager/firewall.go @@ -1,9 +1,17 @@ -package firewall +package manager import ( + "fmt" "net" ) +const ( + NatFormat = "netbird-nat-%s" + ForwardingFormat = "netbird-fwd-%s" + InNatFormat = "netbird-nat-in-%s" + InForwardingFormat = "netbird-fwd-in-%s" +) + // Rule abstraction should be implemented by each firewall manager // // Each firewall type for different OS can use different type @@ -27,10 +35,8 @@ const ( type Action int const ( - // ActionUnknown is a unknown action - ActionUnknown Action = iota // ActionAccept is the action to accept a packet - ActionAccept + ActionAccept Action = iota // ActionDrop is the action to drop a packet ActionDrop ) @@ -56,16 +62,27 @@ type Manager interface { action Action, ipsetName string, comment string, - ) (Rule, error) + ) ([]Rule, error) // DeleteRule from the firewall by rule definition DeleteRule(rule Rule) error + // IsServerRouteSupported returns true if the firewall supports server side routing operations + IsServerRouteSupported() bool + + // InsertRoutingRules inserts a routing firewall rule + InsertRoutingRules(pair RouterPair) error + + // RemoveRoutingRules removes a routing firewall rule + RemoveRoutingRules(pair RouterPair) error + // Reset firewall to the default state Reset() error // Flush the changes to firewall controller Flush() error - - // TODO: migrate routemanager firewal actions to this interface +} + +func GenKey(format string, input string) string { + return fmt.Sprintf(format, input) } diff --git a/client/firewall/port.go b/client/firewall/manager/port.go similarity index 98% rename from client/firewall/port.go rename to client/firewall/manager/port.go index 7681f29c3..9061c1e63 100644 --- a/client/firewall/port.go +++ b/client/firewall/manager/port.go @@ -1,4 +1,4 @@ -package firewall +package manager import ( "strconv" diff --git a/client/firewall/manager/routerpair.go b/client/firewall/manager/routerpair.go new file mode 100644 index 000000000..b63a9f104 --- /dev/null +++ b/client/firewall/manager/routerpair.go @@ -0,0 +1,18 @@ +package manager + +type RouterPair struct { + ID string + Source string + Destination string + Masquerade bool +} + +func GetInPair(pair RouterPair) RouterPair { + return RouterPair{ + ID: pair.ID, + // invert Source/Destination + Source: pair.Destination, + Destination: pair.Source, + Masquerade: pair.Masquerade, + } +} diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go new file mode 100644 index 000000000..4e0d42a5e --- /dev/null +++ b/client/firewall/nftables/acl_linux.go @@ -0,0 +1,1121 @@ +package nftables + +import ( + "bytes" + "encoding/binary" + "fmt" + "net" + "net/netip" + "strconv" + "strings" + "time" + + "github.com/google/nftables" + "github.com/google/nftables/expr" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/iface" +) + +const ( + + // rules chains contains the effective ACL rules + chainNameInputRules = "netbird-acl-input-rules" + chainNameOutputRules = "netbird-acl-output-rules" + + // filter chains contains the rules that jump to the rules chains + chainNameInputFilter = "netbird-acl-input-filter" + chainNameOutputFilter = "netbird-acl-output-filter" + chainNameForwardFilter = "netbird-acl-forward-filter" + + allowNetbirdInputRuleID = "allow Netbird incoming traffic" +) + +var ( + anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + postroutingMark = []byte{0xe4, 0x7, 0x0, 0x00} +) + +type AclManager struct { + rConn *nftables.Conn + sConn *nftables.Conn + wgIface iFaceMapper + routeingFwChainName string + + workTable *nftables.Table + chainInputRules *nftables.Chain + chainOutputRules *nftables.Chain + chainFwFilter *nftables.Chain + chainPrerouting *nftables.Chain + + ipsetStore *ipsetStore + rules map[string]*Rule +} + +// iFaceMapper defines subset methods of interface required for manager +type iFaceMapper interface { + Name() string + Address() iface.WGAddress +} + +func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) { + // sConn is used for creating sets and adding/removing elements from them + // it's differ then rConn (which does create new conn for each flush operation) + // and is permanent. Using same connection for booth type of operations + // overloads netlink with high amount of rules ( > 10000) + sConn, err := nftables.New(nftables.AsLasting()) + if err != nil { + return nil, err + } + + m := &AclManager{ + rConn: &nftables.Conn{}, + sConn: sConn, + wgIface: wgIface, + workTable: table, + routeingFwChainName: routeingFwChainName, + + ipsetStore: newIpsetStore(), + rules: make(map[string]*Rule), + } + + err = m.createDefaultChains() + if err != nil { + return nil, err + } + + return m, nil +} + +// AddFiltering rule to the firewall +// +// If comment argument is empty firewall manager should set +// rule ID as comment for the rule +func (m *AclManager) AddFiltering( + ip net.IP, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + direction firewall.RuleDirection, + action firewall.Action, + ipsetName string, + comment string, +) ([]firewall.Rule, error) { + var ipset *nftables.Set + if ipsetName != "" { + var err error + ipset, err = m.addIpToSet(ipsetName, ip) + if err != nil { + return nil, err + } + } + + newRules := make([]firewall.Rule, 0, 2) + ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, direction, action, ipset, comment) + if err != nil { + return nil, err + } + + newRules = append(newRules, ioRule) + if !shouldAddToPrerouting(proto, dPort, direction) { + return newRules, nil + } + + preroutingRule, err := m.addPreroutingFiltering(ipset, proto, dPort, ip) + if err != nil { + return newRules, err + } + newRules = append(newRules, preroutingRule) + return newRules, nil +} + +// DeleteRule from the firewall by rule definition +func (m *AclManager) DeleteRule(rule firewall.Rule) error { + r, ok := rule.(*Rule) + if !ok { + return fmt.Errorf("invalid rule type") + } + + if r.nftSet == nil { + err := m.rConn.DelRule(r.nftRule) + if err != nil { + log.Errorf("failed to delete rule: %v", err) + } + delete(m.rules, r.GetRuleID()) + return m.rConn.Flush() + } + + ips, ok := m.ipsetStore.ips(r.nftSet.Name) + if !ok { + err := m.rConn.DelRule(r.nftRule) + if err != nil { + log.Errorf("failed to delete rule: %v", err) + } + delete(m.rules, r.GetRuleID()) + return m.rConn.Flush() + } + if _, ok := ips[r.ip.String()]; ok { + err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}}) + if err != nil { + log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err) + } + if err := m.sConn.Flush(); err != nil { + log.Debugf("flush error of set delete element, %s", r.nftSet.Name) + return err + } + m.ipsetStore.DeleteIpFromSet(r.nftSet.Name, r.ip) + } + + // if after delete, set still contains other IPs, + // no need to delete firewall rule and we should exit here + if len(ips) > 0 { + return nil + } + + err := m.rConn.DelRule(r.nftRule) + if err != nil { + log.Errorf("failed to delete rule: %v", err) + } + err = m.rConn.Flush() + if err != nil { + return err + } + + delete(m.rules, r.GetRuleID()) + m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name) + + if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) { + return nil + } + + // we delete last IP from the set, that means we need to delete + // set itself and associated firewall rule too + m.rConn.FlushSet(r.nftSet) + m.rConn.DelSet(r.nftSet) + m.ipsetStore.deleteIpset(r.nftSet.Name) + return nil +} + +// Flush rule/chain/set operations from the buffer +// +// Method also get all rules after flush and refreshes handle values in the rulesets +func (m *AclManager) Flush() error { + if err := m.flushWithBackoff(); err != nil { + return err + } + + if err := m.refreshRuleHandles(m.chainInputRules); err != nil { + log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err) + } + + if err := m.refreshRuleHandles(m.chainOutputRules); err != nil { + log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err) + } + + if err := m.refreshRuleHandles(m.chainPrerouting); err != nil { + log.Errorf("failed to refresh rule handles IPv4 prerouting chain: %v", err) + } + + return nil +} + +func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) { + ruleId := generateRuleId(ip, sPort, dPort, direction, action, ipset) + if r, ok := m.rules[ruleId]; ok { + return &Rule{ + r.nftRule, + r.nftSet, + r.ruleID, + ip, + }, nil + } + + ifaceKey := expr.MetaKeyIIFNAME + if direction == firewall.RuleDirectionOUT { + ifaceKey = expr.MetaKeyOIFNAME + } + expressions := []expr.Any{ + &expr.Meta{Key: ifaceKey, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + } + + if proto != firewall.ProtocolALL { + expressions = append(expressions, &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: uint32(9), + Len: uint32(1), + }) + + var protoData []byte + switch proto { + case firewall.ProtocolTCP: + protoData = []byte{unix.IPPROTO_TCP} + case firewall.ProtocolUDP: + protoData = []byte{unix.IPPROTO_UDP} + case firewall.ProtocolICMP: + protoData = []byte{unix.IPPROTO_ICMP} + default: + return nil, fmt.Errorf("unsupported protocol: %s", proto) + } + expressions = append(expressions, &expr.Cmp{ + Register: 1, + Op: expr.CmpOpEq, + Data: protoData, + }) + } + + rawIP := ip.To4() + // check if rawIP contains zeroed IPv4 0.0.0.0 value + // in that case not add IP match expression into the rule definition + if !bytes.HasPrefix(anyIP, rawIP) { + // source address position + addrOffset := uint32(12) + if direction == firewall.RuleDirectionOUT { + addrOffset += 4 // is ipv4 address length + } + + expressions = append(expressions, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: addrOffset, + Len: 4, + }, + ) + // add individual IP for match if no ipset defined + if ipset == nil { + expressions = append(expressions, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: rawIP, + }, + ) + } else { + expressions = append(expressions, + &expr.Lookup{ + SourceRegister: 1, + SetName: ipset.Name, + SetID: ipset.ID, + }, + ) + } + } + + if sPort != nil && len(sPort.Values) != 0 { + expressions = append(expressions, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: 0, + Len: 2, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: encodePort(*sPort), + }, + ) + } + + if dPort != nil && len(dPort.Values) != 0 { + expressions = append(expressions, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: encodePort(*dPort), + }, + ) + } + + switch action { + case firewall.ActionAccept: + expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept}) + case firewall.ActionDrop: + expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop}) + } + + userData := []byte(strings.Join([]string{ruleId, comment}, " ")) + + var chain *nftables.Chain + if direction == firewall.RuleDirectionIN { + chain = m.chainInputRules + } else { + chain = m.chainOutputRules + } + nftRule := m.rConn.InsertRule(&nftables.Rule{ + Table: m.workTable, + Chain: chain, + Position: 0, + Exprs: expressions, + UserData: userData, + }) + + rule := &Rule{ + nftRule: nftRule, + nftSet: ipset, + ruleID: ruleId, + ip: ip, + } + m.rules[ruleId] = rule + if ipset != nil { + m.ipsetStore.AddReferenceToIpset(ipset.Name) + } + return rule, nil +} + +func (m *AclManager) addPreroutingFiltering(ipset *nftables.Set, proto firewall.Protocol, port *firewall.Port, ip net.IP) (*Rule, error) { + var protoData []byte + switch proto { + case firewall.ProtocolTCP: + protoData = []byte{unix.IPPROTO_TCP} + case firewall.ProtocolUDP: + protoData = []byte{unix.IPPROTO_UDP} + case firewall.ProtocolICMP: + protoData = []byte{unix.IPPROTO_ICMP} + default: + return nil, fmt.Errorf("unsupported protocol: %s", proto) + } + + ruleId := generateRuleIdForMangle(ipset, ip, proto, port) + if r, ok := m.rules[ruleId]; ok { + return &Rule{ + r.nftRule, + r.nftSet, + r.ruleID, + ip, + }, nil + } + + var ipExpression expr.Any + // add individual IP for match if no ipset defined + rawIP := ip.To4() + if ipset == nil { + ipExpression = &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: rawIP, + } + } else { + ipExpression = &expr.Lookup{ + SourceRegister: 1, + SetName: ipset.Name, + SetID: ipset.ID, + } + } + + expressions := []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + ipExpression, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 16, + Len: 4, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: m.wgIface.Address().IP.To4(), + }, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: uint32(9), + Len: uint32(1), + }, + &expr.Cmp{ + Register: 1, + Op: expr.CmpOpEq, + Data: protoData, + }, + } + + if port != nil { + expressions = append(expressions, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: encodePort(*port), + }, + ) + } + + expressions = append(expressions, + &expr.Immediate{ + Register: 1, + Data: postroutingMark, + }, + &expr.Meta{ + Key: expr.MetaKeyMARK, + SourceRegister: true, + Register: 1, + }, + ) + + nftRule := m.rConn.InsertRule(&nftables.Rule{ + Table: m.workTable, + Chain: m.chainPrerouting, + Position: 0, + Exprs: expressions, + UserData: []byte(ruleId), + }) + + if err := m.rConn.Flush(); err != nil { + return nil, fmt.Errorf("flush insert rule: %v", err) + } + + rule := &Rule{ + nftRule: nftRule, + nftSet: ipset, + ruleID: ruleId, + ip: ip, + } + + m.rules[ruleId] = rule + if ipset != nil { + m.ipsetStore.AddReferenceToIpset(ipset.Name) + } + return rule, nil +} + +func (m *AclManager) createDefaultChains() (err error) { + // chainNameInputRules + chain := m.createChain(chainNameInputRules) + err = m.rConn.Flush() + if err != nil { + log.Debugf("failed to create chain (%s): %s", chain.Name, err) + return err + } + m.chainInputRules = chain + + // chainNameOutputRules + chain = m.createChain(chainNameOutputRules) + err = m.rConn.Flush() + if err != nil { + log.Debugf("failed to create chain (%s): %s", chainNameOutputRules, err) + return err + } + m.chainOutputRules = chain + + // netbird-acl-input-filter + // type filter hook input priority filter; policy accept; + chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput) + //netbird-acl-input-filter iifname "wt0" ip saddr 100.72.0.0/16 ip daddr != 100.72.0.0/16 accept + m.addRouteAllowRule(chain, expr.MetaKeyIIFNAME) + m.addFwdAllow(chain, expr.MetaKeyIIFNAME) + m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules + m.addDropExpressions(chain, expr.MetaKeyIIFNAME) + err = m.rConn.Flush() + if err != nil { + log.Debugf("failed to create chain (%s): %s", chain.Name, err) + return err + } + + // netbird-acl-output-filter + // type filter hook output priority filter; policy accept; + chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput) + m.addRouteAllowRule(chain, expr.MetaKeyOIFNAME) + m.addFwdAllow(chain, expr.MetaKeyOIFNAME) + m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules + m.addDropExpressions(chain, expr.MetaKeyOIFNAME) + err = m.rConn.Flush() + if err != nil { + log.Debugf("failed to create chain (%s): %s", chainNameOutputFilter, err) + return err + } + + // netbird-acl-forward-filter + m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) + m.addJumpRulesToRtForward() // to + m.addMarkAccept() + m.addJumpRuleToInputChain() // to netbird-acl-input-rules + m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME) + err = m.rConn.Flush() + if err != nil { + log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err) + return err + } + + // netbird-acl-output-filter + // type filter hook output priority filter; policy accept; + m.chainPrerouting = m.createPreroutingMangle() + err = m.rConn.Flush() + if err != nil { + log.Debugf("failed to create chain (%s): %s", m.chainPrerouting.Name, err) + return err + } + return nil +} + +func (m *AclManager) addJumpRulesToRtForward() { + expressions := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Verdict{ + Kind: expr.VerdictJump, + Chain: m.routeingFwChainName, + }, + } + + _ = m.rConn.AddRule(&nftables.Rule{ + Table: m.workTable, + Chain: m.chainFwFilter, + Exprs: expressions, + }) + + expressions = []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Verdict{ + Kind: expr.VerdictJump, + Chain: m.routeingFwChainName, + }, + } + + _ = m.rConn.AddRule(&nftables.Rule{ + Table: m.workTable, + Chain: m.chainFwFilter, + Exprs: expressions, + }) +} + +func (m *AclManager) addMarkAccept() { + // oifname "wt0" meta mark 0x000007e4 accept + // iifname "wt0" meta mark 0x000007e4 accept + ifaces := []expr.MetaKey{expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME} + for _, iface := range ifaces { + expressions := []expr.Any{ + &expr.Meta{Key: iface, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: postroutingMark, + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } + + _ = m.rConn.AddRule(&nftables.Rule{ + Table: m.workTable, + Chain: m.chainFwFilter, + Exprs: expressions, + }) + } +} + +func (m *AclManager) createChain(name string) *nftables.Chain { + chain := &nftables.Chain{ + Name: name, + Table: m.workTable, + } + + chain = m.rConn.AddChain(chain) + return chain +} + +func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain { + polAccept := nftables.ChainPolicyAccept + chain := &nftables.Chain{ + Name: name, + Table: m.workTable, + Hooknum: hookNum, + Priority: nftables.ChainPriorityFilter, + Type: nftables.ChainTypeFilter, + Policy: &polAccept, + } + + return m.rConn.AddChain(chain) +} + +func (m *AclManager) createPreroutingMangle() *nftables.Chain { + polAccept := nftables.ChainPolicyAccept + chain := &nftables.Chain{ + Name: "netbird-acl-prerouting-filter", + Table: m.workTable, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityMangle, + Type: nftables.ChainTypeFilter, + Policy: &polAccept, + } + + chain = m.rConn.AddChain(chain) + + ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) + expressions := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Payload{ + DestRegister: 2, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + &expr.Bitwise{ + SourceRegister: 2, + DestRegister: 2, + Len: 4, + Xor: []byte{0x0, 0x0, 0x0, 0x0}, + Mask: m.wgIface.Address().Network.Mask, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 2, + Data: ip.Unmap().AsSlice(), + }, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 16, + Len: 4, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: m.wgIface.Address().IP.To4(), + }, + &expr.Immediate{ + Register: 1, + Data: postroutingMark, + }, + &expr.Meta{ + Key: expr.MetaKeyMARK, + SourceRegister: true, + Register: 1, + }, + } + _ = m.rConn.AddRule(&nftables.Rule{ + Table: m.workTable, + Chain: chain, + Exprs: expressions, + }) + chain = m.rConn.AddChain(chain) + return chain +} + +func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any { + expressions := []expr.Any{ + &expr.Meta{Key: ifaceKey, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Verdict{Kind: expr.VerdictDrop}, + } + _ = m.rConn.AddRule(&nftables.Rule{ + Table: m.workTable, + Chain: chain, + Exprs: expressions, + }) + return nil +} + +func (m *AclManager) addJumpRuleToInputChain() { + expressions := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Verdict{ + Kind: expr.VerdictJump, + Chain: m.chainInputRules.Name, + }, + } + + _ = m.rConn.AddRule(&nftables.Rule{ + Table: m.workTable, + Chain: m.chainFwFilter, + Exprs: expressions, + }) +} + +func (m *AclManager) addRouteAllowRule(chain *nftables.Chain, netIfName expr.MetaKey) { + ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) + var srcOp, dstOp expr.CmpOp + if netIfName == expr.MetaKeyIIFNAME { + srcOp = expr.CmpOpEq + dstOp = expr.CmpOpNeq + } else { + srcOp = expr.CmpOpNeq + dstOp = expr.CmpOpEq + } + expressions := []expr.Any{ + &expr.Meta{Key: netIfName, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Payload{ + DestRegister: 2, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + &expr.Bitwise{ + SourceRegister: 2, + DestRegister: 2, + Len: 4, + Xor: []byte{0x0, 0x0, 0x0, 0x0}, + Mask: m.wgIface.Address().Network.Mask, + }, + &expr.Cmp{ + Op: srcOp, + Register: 2, + Data: ip.Unmap().AsSlice(), + }, + &expr.Payload{ + DestRegister: 2, + Base: expr.PayloadBaseNetworkHeader, + Offset: 16, + Len: 4, + }, + &expr.Bitwise{ + SourceRegister: 2, + DestRegister: 2, + Len: 4, + Xor: []byte{0x0, 0x0, 0x0, 0x0}, + Mask: m.wgIface.Address().Network.Mask, + }, + &expr.Cmp{ + Op: dstOp, + Register: 2, + Data: ip.Unmap().AsSlice(), + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } + _ = m.rConn.AddRule(&nftables.Rule{ + Table: chain.Table, + Chain: chain, + Exprs: expressions, + }) +} + +func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { + ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) + var srcOp, dstOp expr.CmpOp + if iifname == expr.MetaKeyIIFNAME { + srcOp = expr.CmpOpNeq + dstOp = expr.CmpOpEq + } else { + srcOp = expr.CmpOpEq + dstOp = expr.CmpOpNeq + } + expressions := []expr.Any{ + &expr.Meta{Key: iifname, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Payload{ + DestRegister: 2, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + &expr.Bitwise{ + SourceRegister: 2, + DestRegister: 2, + Len: 4, + Xor: []byte{0x0, 0x0, 0x0, 0x0}, + Mask: m.wgIface.Address().Network.Mask, + }, + &expr.Cmp{ + Op: srcOp, + Register: 2, + Data: ip.Unmap().AsSlice(), + }, + &expr.Payload{ + DestRegister: 2, + Base: expr.PayloadBaseNetworkHeader, + Offset: 16, + Len: 4, + }, + &expr.Bitwise{ + SourceRegister: 2, + DestRegister: 2, + Len: 4, + Xor: []byte{0x0, 0x0, 0x0, 0x0}, + Mask: m.wgIface.Address().Network.Mask, + }, + &expr.Cmp{ + Op: dstOp, + Register: 2, + Data: ip.Unmap().AsSlice(), + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } + _ = m.rConn.AddRule(&nftables.Rule{ + Table: chain.Table, + Chain: chain, + Exprs: expressions, + }) +} + +func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) { + ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) + expressions := []expr.Any{ + &expr.Meta{Key: ifaceKey, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Payload{ + DestRegister: 2, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + &expr.Bitwise{ + SourceRegister: 2, + DestRegister: 2, + Len: 4, + Xor: []byte{0x0, 0x0, 0x0, 0x0}, + Mask: m.wgIface.Address().Network.Mask, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 2, + Data: ip.Unmap().AsSlice(), + }, + &expr.Payload{ + DestRegister: 2, + Base: expr.PayloadBaseNetworkHeader, + Offset: 16, + Len: 4, + }, + &expr.Bitwise{ + SourceRegister: 2, + DestRegister: 2, + Len: 4, + Xor: []byte{0x0, 0x0, 0x0, 0x0}, + Mask: m.wgIface.Address().Network.Mask, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 2, + Data: ip.Unmap().AsSlice(), + }, + &expr.Verdict{ + Kind: expr.VerdictJump, + Chain: to, + }, + } + _ = m.rConn.AddRule(&nftables.Rule{ + Table: chain.Table, + Chain: chain, + Exprs: expressions, + }) +} + +func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) { + ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName) + rawIP := ip.To4() + if err != nil { + if ipset, err = m.createSet(m.workTable, ipsetName); err != nil { + return nil, fmt.Errorf("get set name: %v", err) + } + + m.ipsetStore.newIpset(ipset.Name) + } + + if m.ipsetStore.IsIpInSet(ipset.Name, ip) { + return ipset, nil + } + + if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil { + return nil, fmt.Errorf("add set element for the first time: %v", err) + } + + m.ipsetStore.AddIpToSet(ipset.Name, ip) + + if err := m.sConn.Flush(); err != nil { + return nil, fmt.Errorf("flush add elements: %v", err) + } + + return ipset, nil +} + +// createSet in given table by name +func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Set, error) { + ipset := &nftables.Set{ + Name: name, + Table: table, + Dynamic: true, + KeyType: nftables.TypeIPAddr, + } + + if err := m.rConn.AddSet(ipset, nil); err != nil { + return nil, fmt.Errorf("create set: %v", err) + } + + if err := m.rConn.Flush(); err != nil { + return nil, fmt.Errorf("flush created set: %v", err) + } + + return ipset, nil +} + +func (m *AclManager) flushWithBackoff() (err error) { + backoff := 4 + backoffTime := 1000 * time.Millisecond + for i := 0; ; i++ { + err = m.rConn.Flush() + if err != nil { + if !strings.Contains(err.Error(), "busy") { + return + } + log.Error("failed to flush nftables, retrying...") + if i == backoff-1 { + return err + } + time.Sleep(backoffTime) + backoffTime *= 2 + continue + } + break + } + return +} + +func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error { + if m.workTable == nil || chain == nil { + return nil + } + + list, err := m.rConn.GetRules(m.workTable, chain) + if err != nil { + return err + } + + for _, rule := range list { + if len(rule.UserData) == 0 { + continue + } + split := bytes.Split(rule.UserData, []byte(" ")) + r, ok := m.rules[string(split[0])] + if ok { + *r.nftRule = *rule + } + } + + return nil +} + +func generateRuleId( + ip net.IP, + sPort *firewall.Port, + dPort *firewall.Port, + direction firewall.RuleDirection, + action firewall.Action, + ipset *nftables.Set, +) string { + rulesetID := ":" + strconv.Itoa(int(direction)) + ":" + if sPort != nil { + rulesetID += sPort.String() + } + rulesetID += ":" + if dPort != nil { + rulesetID += dPort.String() + } + rulesetID += ":" + rulesetID += strconv.Itoa(int(action)) + if ipset == nil { + return "ip:" + ip.String() + rulesetID + } + return "set:" + ipset.Name + rulesetID +} +func generateRuleIdForMangle(ipset *nftables.Set, ip net.IP, proto firewall.Protocol, port *firewall.Port) string { + // case of icmp port is empty + var p string + if port != nil { + p = port.String() + } + if ipset != nil { + return fmt.Sprintf("p:set:%s:%s:%v", ipset.Name, proto, p) + } else { + return fmt.Sprintf("p:ip:%s:%s:%v", ip.String(), proto, p) + } +} + +func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool { + if proto == "all" { + return false + } + + if direction != firewall.RuleDirectionIN { + return false + } + + if dPort == nil && proto != firewall.ProtocolICMP { + return false + } + return true +} + +func encodePort(port firewall.Port) []byte { + bs := make([]byte, 2) + binary.BigEndian.PutUint16(bs, uint16(port.Values[0])) + return bs +} + +func ifname(n string) []byte { + b := make([]byte, 16) + copy(b, []byte(n+"\x00")) + return b +} diff --git a/client/firewall/nftables/ipsetstore_linux.go b/client/firewall/nftables/ipsetstore_linux.go new file mode 100644 index 000000000..a6c2e9496 --- /dev/null +++ b/client/firewall/nftables/ipsetstore_linux.go @@ -0,0 +1,85 @@ +package nftables + +import ( + "net" +) + +type ipsetStore struct { + ipsetReference map[string]int + ipsets map[string]map[string]struct{} // ipsetName -> list of ips +} + +func newIpsetStore() *ipsetStore { + return &ipsetStore{ + ipsetReference: make(map[string]int), + ipsets: make(map[string]map[string]struct{}), + } +} + +func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) { + r, ok := s.ipsets[ipsetName] + return r, ok +} + +func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} { + s.ipsetReference[ipsetName] = 0 + ipList := make(map[string]struct{}) + s.ipsets[ipsetName] = ipList + return ipList +} + +func (s *ipsetStore) deleteIpset(ipsetName string) { + delete(s.ipsetReference, ipsetName) + delete(s.ipsets, ipsetName) +} + +func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) { + ipList, ok := s.ipsets[ipsetName] + if !ok { + return + } + delete(ipList, ip.String()) +} + +func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) { + ipList, ok := s.ipsets[ipsetName] + if !ok { + return + } + ipList[ip.String()] = struct{}{} +} + +func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool { + ipList, ok := s.ipsets[ipsetName] + if !ok { + return false + } + _, ok = ipList[ip.String()] + return ok +} + +func (s *ipsetStore) AddReferenceToIpset(ipsetName string) { + s.ipsetReference[ipsetName]++ +} + +func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) { + r, ok := s.ipsetReference[ipsetName] + if !ok { + return + } + if r == 0 { + return + } + s.ipsetReference[ipsetName]-- +} + +func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool { + if _, ok := s.ipsetReference[ipsetName]; !ok { + return false + } + if s.ipsetReference[ipsetName] == 0 { + return false + } + + return true +} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 93379bad8..fad2d7804 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -2,90 +2,52 @@ package nftables import ( "bytes" - "encoding/binary" + "context" "fmt" "net" - "net/netip" - "strconv" - "strings" "sync" - "time" "github.com/google/nftables" "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" - "golang.org/x/sys/unix" - fw "github.com/netbirdio/netbird/client/firewall" - "github.com/netbirdio/netbird/iface" + firewall "github.com/netbirdio/netbird/client/firewall/manager" ) const ( - // FilterTableName is the name of the table that is used for filtering by the Netbird client - FilterTableName = "netbird-acl" - - // FilterInputChainName is the name of the chain that is used for filtering incoming packets - FilterInputChainName = "netbird-acl-input-filter" - - // FilterOutputChainName is the name of the chain that is used for filtering outgoing packets - FilterOutputChainName = "netbird-acl-output-filter" - - AllowNetbirdInputRuleID = "allow Netbird incoming traffic" + // tableName is the name of the table that is used for filtering by the Netbird client + tableName = "netbird" ) -var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} - // Manager of iptables firewall type Manager struct { - mutex sync.Mutex - - rConn *nftables.Conn - sConn *nftables.Conn - tableIPv4 *nftables.Table - tableIPv6 *nftables.Table - - filterInputChainIPv4 *nftables.Chain - filterOutputChainIPv4 *nftables.Chain - - filterInputChainIPv6 *nftables.Chain - filterOutputChainIPv6 *nftables.Chain - - rulesetManager *rulesetManager - setRemovedIPs map[string]struct{} - setRemoved map[string]*nftables.Set - + mutex sync.Mutex + rConn *nftables.Conn wgIface iFaceMapper -} -// iFaceMapper defines subset methods of interface required for manager -type iFaceMapper interface { - Name() string - Address() iface.WGAddress + router *router + aclManager *AclManager } // Create nftables firewall manager -func Create(wgIface iFaceMapper) (*Manager, error) { - // sConn is used for creating sets and adding/removing elements from them - // it's differ then rConn (which does create new conn for each flush operation) - // and is permanent. Using same connection for booth type of operations - // overloads netlink with high amount of rules ( > 10000) - sConn, err := nftables.New(nftables.AsLasting()) +func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { + m := &Manager{ + rConn: &nftables.Conn{}, + wgIface: wgIface, + } + + workTable, err := m.createWorkTable() if err != nil { return nil, err } - m := &Manager{ - rConn: &nftables.Conn{}, - sConn: sConn, - - rulesetManager: newRuleManager(), - setRemovedIPs: map[string]struct{}{}, - setRemoved: map[string]*nftables.Set{}, - - wgIface: wgIface, + m.router, err = newRouter(context, workTable) + if err != nil { + return nil, err } - if err := m.Reset(); err != nil { + m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName()) + if err != nil { return nil, err } @@ -98,649 +60,58 @@ func Create(wgIface iFaceMapper) (*Manager, error) { // rule ID as comment for the rule func (m *Manager) AddFiltering( ip net.IP, - proto fw.Protocol, - sPort *fw.Port, - dPort *fw.Port, - direction fw.RuleDirection, - action fw.Action, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + direction firewall.RuleDirection, + action firewall.Action, ipsetName string, comment string, -) (fw.Rule, error) { +) ([]firewall.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() - var ( - err error - ipset *nftables.Set - table *nftables.Table - chain *nftables.Chain - ) - - if direction == fw.RuleDirectionOUT { - table, chain, err = m.chain( - ip, - FilterOutputChainName, - nftables.ChainHookOutput, - nftables.ChainPriorityFilter, - nftables.ChainTypeFilter) - } else { - table, chain, err = m.chain( - ip, - FilterInputChainName, - nftables.ChainHookInput, - nftables.ChainPriorityFilter, - nftables.ChainTypeFilter) - } - if err != nil { - return nil, err - } - rawIP := ip.To4() if rawIP == nil { - rawIP = ip.To16() + return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) } - rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName) - - if ipsetName != "" { - // if we already have set with given name, just add ip to the set - // and return rule with new ID in other case let's create rule - // with fresh created set and set element - - var isSetNew bool - ipset, err = m.rConn.GetSetByName(table, ipsetName) - if err != nil { - if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil { - return nil, fmt.Errorf("get set name: %v", err) - } - isSetNew = true - } - - if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil { - return nil, fmt.Errorf("add set element for the first time: %v", err) - } - if err := m.sConn.Flush(); err != nil { - return nil, fmt.Errorf("flush add elements: %v", err) - } - - if !isSetNew { - // if we already have nftables rules with set for given direction - // just add new rule to the ruleset and return new fw.Rule object - - if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok { - return m.rulesetManager.addRule(ruleset, rawIP) - } - // if ipset exists but it is not linked to rule for given direction - // create new rule for direction and bind ipset to it later - } - } - - ifaceKey := expr.MetaKeyIIFNAME - if direction == fw.RuleDirectionOUT { - ifaceKey = expr.MetaKeyOIFNAME - } - expressions := []expr.Any{ - &expr.Meta{Key: ifaceKey, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - } - - if proto != "all" { - expressions = append(expressions, &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: uint32(9), - Len: uint32(1), - }) - - var protoData []byte - switch proto { - case fw.ProtocolTCP: - protoData = []byte{unix.IPPROTO_TCP} - case fw.ProtocolUDP: - protoData = []byte{unix.IPPROTO_UDP} - case fw.ProtocolICMP: - protoData = []byte{unix.IPPROTO_ICMP} - default: - return nil, fmt.Errorf("unsupported protocol: %s", proto) - } - expressions = append(expressions, &expr.Cmp{ - Register: 1, - Op: expr.CmpOpEq, - Data: protoData, - }) - } - - // check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value - // in that case not add IP match expression into the rule definition - if !bytes.HasPrefix(anyIP, rawIP) { - // source address position - addrLen := uint32(len(rawIP)) - addrOffset := uint32(12) - if addrLen == 16 { - addrOffset = 8 - } - - // change to destination address position if need - if direction == fw.RuleDirectionOUT { - addrOffset += addrLen - } - - expressions = append(expressions, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: addrOffset, - Len: addrLen, - }, - ) - // add individual IP for match if no ipset defined - if ipset == nil { - expressions = append(expressions, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: rawIP, - }, - ) - } else { - expressions = append(expressions, - &expr.Lookup{ - SourceRegister: 1, - SetName: ipsetName, - SetID: ipset.ID, - }, - ) - } - } - - if sPort != nil && len(sPort.Values) != 0 { - expressions = append(expressions, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseTransportHeader, - Offset: 0, - Len: 2, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: encodePort(*sPort), - }, - ) - } - - if dPort != nil && len(dPort.Values) != 0 { - expressions = append(expressions, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseTransportHeader, - Offset: 2, - Len: 2, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: encodePort(*dPort), - }, - ) - } - - if action == fw.ActionAccept { - expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept}) - } else { - expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop}) - } - - userData := []byte(strings.Join([]string{rulesetID, comment}, " ")) - - rule := m.rConn.InsertRule(&nftables.Rule{ - Table: table, - Chain: chain, - Position: 0, - Exprs: expressions, - UserData: userData, - }) - if err := m.rConn.Flush(); err != nil { - return nil, fmt.Errorf("flush insert rule: %v", err) - } - - ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset) - return m.rulesetManager.addRule(ruleset, rawIP) -} - -// getRulesetID returns ruleset ID based on given parameters -func (m *Manager) getRulesetID( - ip net.IP, - proto fw.Protocol, - sPort *fw.Port, - dPort *fw.Port, - direction fw.RuleDirection, - action fw.Action, - ipsetName string, -) string { - rulesetID := ":" + strconv.Itoa(int(direction)) + ":" - if sPort != nil { - rulesetID += sPort.String() - } - rulesetID += ":" - if dPort != nil { - rulesetID += dPort.String() - } - rulesetID += ":" - rulesetID += strconv.Itoa(int(action)) - if ipsetName == "" { - return "ip:" + ip.String() + rulesetID - } - return "set:" + ipsetName + rulesetID -} - -// createSet in given table by name -func (m *Manager) createSet( - table *nftables.Table, - rawIP []byte, - name string, -) (*nftables.Set, error) { - keyType := nftables.TypeIPAddr - if len(rawIP) == 16 { - keyType = nftables.TypeIP6Addr - } - // else we create new ipset and continue creating rule - ipset := &nftables.Set{ - Name: name, - Table: table, - Dynamic: true, - KeyType: keyType, - } - - if err := m.rConn.AddSet(ipset, nil); err != nil { - return nil, fmt.Errorf("create set: %v", err) - } - - if err := m.rConn.Flush(); err != nil { - return nil, fmt.Errorf("flush created set: %v", err) - } - - return ipset, nil -} - -// chain returns the chain for the given IP address with specific settings -func (m *Manager) chain( - ip net.IP, - name string, - hook nftables.ChainHook, - priority nftables.ChainPriority, - cType nftables.ChainType, -) (*nftables.Table, *nftables.Chain, error) { - var err error - - getChain := func(c *nftables.Chain, tf nftables.TableFamily) (*nftables.Chain, error) { - if c != nil { - return c, nil - } - return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType) - } - - if ip.To4() != nil { - if name == FilterInputChainName { - m.filterInputChainIPv4, err = getChain(m.filterInputChainIPv4, nftables.TableFamilyIPv4) - return m.tableIPv4, m.filterInputChainIPv4, err - } - m.filterOutputChainIPv4, err = getChain(m.filterOutputChainIPv4, nftables.TableFamilyIPv4) - return m.tableIPv4, m.filterOutputChainIPv4, err - } - if name == FilterInputChainName { - m.filterInputChainIPv6, err = getChain(m.filterInputChainIPv6, nftables.TableFamilyIPv6) - return m.tableIPv4, m.filterInputChainIPv6, err - } - m.filterOutputChainIPv6, err = getChain(m.filterOutputChainIPv6, nftables.TableFamilyIPv6) - return m.tableIPv4, m.filterOutputChainIPv6, err -} - -// table returns the table for the given family of the IP address -func (m *Manager) table( - family nftables.TableFamily, tableName string, -) (*nftables.Table, error) { - // we cache access to Netbird ACL table only - if tableName != FilterTableName { - return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName) - } - - if family == nftables.TableFamilyIPv4 { - if m.tableIPv4 != nil { - return m.tableIPv4, nil - } - - table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName) - if err != nil { - return nil, err - } - m.tableIPv4 = table - return m.tableIPv4, nil - } - - if m.tableIPv6 != nil { - return m.tableIPv6, nil - } - - table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName) - if err != nil { - return nil, err - } - m.tableIPv6 = table - return m.tableIPv6, nil -} - -func (m *Manager) createTableIfNotExists( - family nftables.TableFamily, tableName string, -) (*nftables.Table, error) { - tables, err := m.rConn.ListTablesOfFamily(family) - if err != nil { - return nil, fmt.Errorf("list of tables: %w", err) - } - - for _, t := range tables { - if t.Name == tableName { - return t, nil - } - } - - table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}) - if err := m.rConn.Flush(); err != nil { - return nil, err - } - return table, nil -} - -func (m *Manager) createChainIfNotExists( - family nftables.TableFamily, - tableName string, - name string, - hooknum nftables.ChainHook, - priority nftables.ChainPriority, - chainType nftables.ChainType, -) (*nftables.Chain, error) { - table, err := m.table(family, tableName) - if err != nil { - return nil, err - } - - chains, err := m.rConn.ListChainsOfTableFamily(family) - if err != nil { - return nil, fmt.Errorf("list of chains: %w", err) - } - - for _, c := range chains { - if c.Name == name && c.Table.Name == table.Name { - return c, nil - } - } - - polAccept := nftables.ChainPolicyAccept - chain := &nftables.Chain{ - Name: name, - Table: table, - Hooknum: hooknum, - Priority: priority, - Type: chainType, - Policy: &polAccept, - } - - chain = m.rConn.AddChain(chain) - - ifaceKey := expr.MetaKeyIIFNAME - shiftDSTAddr := 0 - if name == FilterOutputChainName { - ifaceKey = expr.MetaKeyOIFNAME - shiftDSTAddr = 1 - } - - expressions := []expr.Any{ - &expr.Meta{Key: ifaceKey, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - } - - mask, _ := netip.AddrFromSlice(m.wgIface.Address().Network.Mask) - if m.wgIface.Address().IP.To4() == nil { - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To16()) - expressions = append(expressions, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: uint32(8 + (16 * shiftDSTAddr)), - Len: 16, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 16, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: mask.Unmap().AsSlice(), - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Verdict{Kind: expr.VerdictAccept}, - ) - } else { - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - expressions = append(expressions, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: uint32(12 + (4 * shiftDSTAddr)), - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Verdict{Kind: expr.VerdictAccept}, - ) - } - - _ = m.rConn.AddRule(&nftables.Rule{ - Table: table, - Chain: chain, - Exprs: expressions, - }) - - expressions = []expr.Any{ - &expr.Meta{Key: ifaceKey, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Verdict{Kind: expr.VerdictDrop}, - } - _ = m.rConn.AddRule(&nftables.Rule{ - Table: table, - Chain: chain, - Exprs: expressions, - }) - - if err := m.rConn.Flush(); err != nil { - return nil, err - } - - return chain, nil + return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) } // DeleteRule from the firewall by rule definition -func (m *Manager) DeleteRule(rule fw.Rule) error { +func (m *Manager) DeleteRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() - nativeRule, ok := rule.(*Rule) - if !ok { - return fmt.Errorf("invalid rule type") - } - - if nativeRule.nftRule == nil { - return nil - } - - if nativeRule.nftSet != nil { - // call twice of delete set element raises error - // so we need to check if element is already removed - key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip) - if _, ok := m.setRemovedIPs[key]; !ok { - err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}}) - if err != nil { - log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err) - } - if err := m.sConn.Flush(); err != nil { - return err - } - m.setRemovedIPs[key] = struct{}{} - } - } - - if m.rulesetManager.deleteRule(nativeRule) { - // deleteRule indicates that we still have IP in the ruleset - // it means we should not remove the nftables rule but need to update set - // so we prepare IP to be removed from set on the next flush call - return nil - } - - // ruleset doesn't contain IP anymore (or contains only one), remove nft rule - if err := m.rConn.DelRule(nativeRule.nftRule); err != nil { - log.Errorf("failed to delete rule: %v", err) - } - if err := m.rConn.Flush(); err != nil { - return err - } - nativeRule.nftRule = nil - - if nativeRule.nftSet != nil { - if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok { - m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet - } - nativeRule.nftSet = nil - } - - return nil + return m.aclManager.DeleteRule(rule) } -// Reset firewall to the default state -func (m *Manager) Reset() error { - m.mutex.Lock() - defer m.mutex.Unlock() - - chains, err := m.rConn.ListChains() - if err != nil { - return fmt.Errorf("list of chains: %w", err) - } - for _, c := range chains { - // delete Netbird allow input traffic rule if it exists - if c.Table.Name == "filter" && c.Name == "INPUT" { - rules, err := m.rConn.GetRules(c.Table, c) - if err != nil { - log.Errorf("get rules for chain %q: %v", c.Name, err) - continue - } - for _, r := range rules { - if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) { - if err := m.rConn.DelRule(r); err != nil { - log.Errorf("delete rule: %v", err) - } - } - } - } - - if c.Name == FilterInputChainName || c.Name == FilterOutputChainName { - m.rConn.DelChain(c) - } - } - - tables, err := m.rConn.ListTables() - if err != nil { - return fmt.Errorf("list of tables: %w", err) - } - for _, t := range tables { - if t.Name == FilterTableName { - m.rConn.DelTable(t) - } - } - - return m.rConn.Flush() +func (m *Manager) IsServerRouteSupported() bool { + return true } -// Flush rule/chain/set operations from the buffer -// -// Method also get all rules after flush and refreshes handle values in the rulesets -func (m *Manager) Flush() error { +func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - if err := m.flushWithBackoff(); err != nil { - return err - } + return m.router.InsertRoutingRules(pair) +} - // set must be removed after flush rule changes - // otherwise we will get error - for _, s := range m.setRemoved { - m.rConn.FlushSet(s) - m.rConn.DelSet(s) - } +func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { + m.mutex.Lock() + defer m.mutex.Unlock() - if len(m.setRemoved) > 0 { - if err := m.flushWithBackoff(); err != nil { - return err - } - } - - m.setRemovedIPs = map[string]struct{}{} - m.setRemoved = map[string]*nftables.Set{} - - if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil { - log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err) - } - - if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil { - log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err) - } - - if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil { - log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err) - } - - if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil { - log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err) - } - - return nil + return m.router.RemoveRoutingRules(pair) } // AllowNetbird allows netbird interface traffic +// todo review this method usage func (m *Manager) AllowNetbird() error { m.mutex.Lock() defer m.mutex.Unlock() - tf := nftables.TableFamilyIPv4 - if m.wgIface.Address().IP.To4() == nil { - tf = nftables.TableFamilyIPv6 - } - - chains, err := m.rConn.ListChainsOfTableFamily(tf) + chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) if err != nil { return fmt.Errorf("list of chains: %w", err) } @@ -777,47 +148,75 @@ func (m *Manager) AllowNetbird() error { return nil } -func (m *Manager) flushWithBackoff() (err error) { - backoff := 4 - backoffTime := 1000 * time.Millisecond - for i := 0; ; i++ { - err = m.rConn.Flush() - if err != nil { - if !strings.Contains(err.Error(), "busy") { - return - } - log.Error("failed to flush nftables, retrying...") - if i == backoff-1 { - return err - } - time.Sleep(backoffTime) - backoffTime *= 2 - continue - } - break +// Reset firewall to the default state +func (m *Manager) Reset() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + chains, err := m.rConn.ListChains() + if err != nil { + return fmt.Errorf("list of chains: %w", err) } - return + + for _, c := range chains { + // delete Netbird allow input traffic rule if it exists + if c.Table.Name == "filter" && c.Name == "INPUT" { + rules, err := m.rConn.GetRules(c.Table, c) + if err != nil { + log.Errorf("get rules for chain %q: %v", c.Name, err) + continue + } + for _, r := range rules { + if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { + if err := m.rConn.DelRule(r); err != nil { + log.Errorf("delete rule: %v", err) + } + } + } + } + } + + m.router.ResetForwardRules() + + tables, err := m.rConn.ListTables() + if err != nil { + return fmt.Errorf("list of tables: %w", err) + } + for _, t := range tables { + if t.Name == tableName { + m.rConn.DelTable(t) + } + } + + return m.rConn.Flush() } -func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error { - if table == nil || chain == nil { - return nil - } +// Flush rule/chain/set operations from the buffer +// +// Method also get all rules after flush and refreshes handle values in the rulesets +// todo review this method usage +func (m *Manager) Flush() error { + m.mutex.Lock() + defer m.mutex.Unlock() - list, err := m.rConn.GetRules(table, chain) + return m.aclManager.Flush() +} + +func (m *Manager) createWorkTable() (*nftables.Table, error) { + tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { - return err + return nil, fmt.Errorf("list of tables: %w", err) } - for _, rule := range list { - if len(rule.UserData) != 0 { - if err := m.rulesetManager.setNftRuleHandle(rule); err != nil { - log.Errorf("failed to set rule handle: %v", err) - } + for _, t := range tables { + if t.Name == tableName { + m.rConn.DelTable(t) } } - return nil + table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}) + err = m.rConn.Flush() + return table, err } func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { @@ -835,7 +234,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { Kind: expr.VerdictAccept, }, }, - UserData: []byte(AllowNetbirdInputRuleID), + UserData: []byte(allowNetbirdInputRuleID), } _ = m.rConn.InsertRule(rule) } @@ -857,15 +256,3 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable } return nil } - -func encodePort(port fw.Port) []byte { - bs := make([]byte, 2) - binary.BigEndian.PutUint16(bs, uint16(port.Values[0])) - return bs -} - -func ifname(n string) []byte { - b := make([]byte, 16) - copy(b, []byte(n+"\x00")) - return b -} diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 0a5c499b2..74ddaf6e1 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -1,6 +1,7 @@ package nftables import ( + "context" "fmt" "net" "net/netip" @@ -12,7 +13,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/sys/unix" - fw "github.com/netbirdio/netbird/client/firewall" + fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/iface" ) @@ -53,7 +54,7 @@ func TestNftablesManager(t *testing.T) { } // just check on the local interface - manager, err := Create(mock) + manager, err := Create(context.Background(), mock) require.NoError(t, err) time.Sleep(time.Second * 3) @@ -82,14 +83,10 @@ func TestNftablesManager(t *testing.T) { err = manager.Flush() require.NoError(t, err, "failed to flush") - rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4) + rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) require.NoError(t, err, "failed to get rules") - // test expectations: - // 1) regular rule - // 2) "accept extra routed traffic rule" for the interface - // 3) "drop all rule" for the interface - require.Len(t, rules, 3, "expected 3 rules") + require.Len(t, rules, 1, "expected 1 rules") ipToAdd, _ := netip.AddrFromSlice(ip) add := ipToAdd.Unmap() @@ -137,18 +134,17 @@ func TestNftablesManager(t *testing.T) { } require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions") - err = manager.DeleteRule(rule) - require.NoError(t, err, "failed to delete rule") + for _, r := range rule { + err = manager.DeleteRule(r) + require.NoError(t, err, "failed to delete rule") + } err = manager.Flush() require.NoError(t, err, "failed to flush") - rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4) + rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) require.NoError(t, err, "failed to get rules") - // test expectations: - // 1) "accept extra routed traffic rule" for the interface - // 2) "drop all rule" for the interface - require.Len(t, rules, 2, "expected 2 rules after deletion") + require.Len(t, rules, 0, "expected 0 rules after deletion") err = manager.Reset() require.NoError(t, err, "failed to reset") @@ -173,7 +169,7 @@ func TestNFtablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(mock) + manager, err := Create(context.Background(), mock) require.NoError(t, err) time.Sleep(time.Second * 3) diff --git a/client/firewall/nftables/route_linux.go b/client/firewall/nftables/route_linux.go new file mode 100644 index 000000000..381136e50 --- /dev/null +++ b/client/firewall/nftables/route_linux.go @@ -0,0 +1,413 @@ +package nftables + +import ( + "bytes" + "context" + "errors" + "fmt" + "net" + "net/netip" + + "github.com/google/nftables" + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/firewall/manager" +) + +const ( + chainNameRouteingFw = "netbird-rt-fwd" + chainNameRoutingNat = "netbird-rt-nat" + + userDataAcceptForwardRuleSrc = "frwacceptsrc" + userDataAcceptForwardRuleDst = "frwacceptdst" +) + +// some presets for building nftable rules +var ( + zeroXor = binaryutil.NativeEndian.PutUint32(0) + + exprCounterAccept = []expr.Any{ + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } + + errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found") +) + +type router struct { + ctx context.Context + stop context.CancelFunc + conn *nftables.Conn + workTable *nftables.Table + filterTable *nftables.Table + chains map[string]*nftables.Chain + // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules + rules map[string]*nftables.Rule + isDefaultFwdRulesEnabled bool +} + +func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) { + ctx, cancel := context.WithCancel(parentCtx) + + r := &router{ + ctx: ctx, + stop: cancel, + conn: &nftables.Conn{}, + workTable: workTable, + chains: make(map[string]*nftables.Chain), + rules: make(map[string]*nftables.Rule), + } + + var err error + r.filterTable, err = r.loadFilterTable() + if err != nil { + if errors.Is(err, errFilterTableNotFound) { + log.Warnf("table 'filter' not found for forward rules") + } else { + return nil, err + } + } + + err = r.cleanUpDefaultForwardRules() + if err != nil { + log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + } + + err = r.createContainers() + if err != nil { + log.Errorf("failed to create containers for route: %s", err) + } + return r, err +} + +func (r *router) RouteingFwChainName() string { + return chainNameRouteingFw +} + +// ResetForwardRules cleans existing nftables default forward rules from the system +func (r *router) ResetForwardRules() { + err := r.cleanUpDefaultForwardRules() + if err != nil { + log.Errorf("failed to reset forward rules: %s", err) + } +} + +func (r *router) loadFilterTable() (*nftables.Table, error) { + tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) + if err != nil { + return nil, fmt.Errorf("nftables: unable to list tables: %v", err) + } + + for _, table := range tables { + if table.Name == "filter" { + return table, nil + } + } + + return nil, errFilterTableNotFound +} + +func (r *router) createContainers() error { + + r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameRouteingFw, + Table: r.workTable, + }) + + r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameRoutingNat, + Table: r.workTable, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityNATSource - 1, + Type: nftables.ChainTypeNAT, + }) + + err := r.refreshRulesMap() + if err != nil { + log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + } + + err = r.conn.Flush() + if err != nil { + return fmt.Errorf("nftables: unable to initialize table: %v", err) + } + return nil +} + +// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain +func (r *router) InsertRoutingRules(pair manager.RouterPair) error { + err := r.refreshRulesMap() + if err != nil { + return err + } + + err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false) + if err != nil { + return err + } + err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false) + if err != nil { + return err + } + + if pair.Masquerade { + err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true) + if err != nil { + return err + } + err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true) + if err != nil { + return err + } + } + + if r.filterTable != nil && !r.isDefaultFwdRulesEnabled { + log.Debugf("add default accept forward rule") + r.acceptForwardRule(pair.Source) + } + + err = r.conn.Flush() + if err != nil { + return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err) + } + return nil +} + +// insertRoutingRule inserts a nftable rule to the conn client flush queue +func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error { + sourceExp := generateCIDRMatcherExpressions(true, pair.Source) + destExp := generateCIDRMatcherExpressions(false, pair.Destination) + + var expression []expr.Any + if isNat { + expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic + } else { + expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic + } + + ruleKey := manager.GenKey(format, pair.ID) + + _, exists := r.rules[ruleKey] + if exists { + err := r.removeRoutingRule(format, pair) + if err != nil { + return err + } + } + + r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainName], + Exprs: expression, + UserData: []byte(ruleKey), + }) + return nil +} + +func (r *router) acceptForwardRule(sourceNetwork string) { + src := generateCIDRMatcherExpressions(true, sourceNetwork) + dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0") + + var exprs []expr.Any + exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic + Kind: expr.VerdictAccept, + })...) + + rule := &nftables.Rule{ + Table: r.filterTable, + Chain: &nftables.Chain{ + Name: "FORWARD", + Table: r.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: exprs, + UserData: []byte(userDataAcceptForwardRuleSrc), + } + + r.conn.AddRule(rule) + + src = generateCIDRMatcherExpressions(true, "0.0.0.0/0") + dst = generateCIDRMatcherExpressions(false, sourceNetwork) + + exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic + Kind: expr.VerdictAccept, + })...) + + rule = &nftables.Rule{ + Table: r.filterTable, + Chain: &nftables.Chain{ + Name: "FORWARD", + Table: r.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: exprs, + UserData: []byte(userDataAcceptForwardRuleDst), + } + r.conn.AddRule(rule) + r.isDefaultFwdRulesEnabled = true +} + +// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains +func (r *router) RemoveRoutingRules(pair manager.RouterPair) error { + err := r.refreshRulesMap() + if err != nil { + return err + } + + err = r.removeRoutingRule(manager.ForwardingFormat, pair) + if err != nil { + return err + } + + err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair)) + if err != nil { + return err + } + + err = r.removeRoutingRule(manager.NatFormat, pair) + if err != nil { + return err + } + + err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair)) + if err != nil { + return err + } + + if len(r.rules) == 0 { + err := r.cleanUpDefaultForwardRules() + if err != nil { + log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + } + } + + err = r.conn.Flush() + if err != nil { + return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) + } + log.Debugf("nftables: removed rules for %s", pair.Destination) + return nil +} + +// removeRoutingRule add a nftable rule to the removal queue and delete from rules map +func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error { + ruleKey := manager.GenKey(format, pair.ID) + + rule, found := r.rules[ruleKey] + if found { + ruleType := "forwarding" + if rule.Chain.Type == nftables.ChainTypeNAT { + ruleType = "nat" + } + + err := r.conn.DelRule(rule) + if err != nil { + return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err) + } + + log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination) + + delete(r.rules, ruleKey) + } + return nil +} + +// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid +// duplicates and to get missing attributes that we don't have when adding new rules +func (r *router) refreshRulesMap() error { + for _, chain := range r.chains { + rules, err := r.conn.GetRules(chain.Table, chain) + if err != nil { + return fmt.Errorf("nftables: unable to list rules: %v", err) + } + for _, rule := range rules { + if len(rule.UserData) > 0 { + r.rules[string(rule.UserData)] = rule + } + } + } + return nil +} + +func (r *router) cleanUpDefaultForwardRules() error { + if r.filterTable == nil { + r.isDefaultFwdRulesEnabled = false + return nil + } + + chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) + if err != nil { + return err + } + + var rules []*nftables.Rule + for _, chain := range chains { + if chain.Table.Name != r.filterTable.Name { + continue + } + if chain.Name != "FORWARD" { + continue + } + + rules, err = r.conn.GetRules(r.filterTable, chain) + if err != nil { + return err + } + } + + for _, rule := range rules { + if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) { + err := r.conn.DelRule(rule) + if err != nil { + return err + } + } + } + r.isDefaultFwdRulesEnabled = false + return r.conn.Flush() +} + +// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR +func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any { + ip, network, _ := net.ParseCIDR(cidr) + ipToAdd, _ := netip.AddrFromSlice(ip) + add := ipToAdd.Unmap() + + var offSet uint32 + if source { + offSet = 12 // src offset + } else { + offSet = 16 // dst offset + } + + return []expr.Any{ + // fetch src add + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: offSet, + Len: 4, + }, + // net mask + &expr.Bitwise{ + DestRegister: 1, + SourceRegister: 1, + Len: 4, + Mask: network.Mask, + Xor: zeroXor, + }, + // net address + &expr.Cmp{ + Register: 1, + Data: add.AsSlice(), + }, + } +} diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go new file mode 100644 index 000000000..aa1224a5a --- /dev/null +++ b/client/firewall/nftables/router_linux_test.go @@ -0,0 +1,280 @@ +//go:build !android + +package nftables + +import ( + "context" + "testing" + + "github.com/coreos/go-iptables/iptables" + "github.com/google/nftables" + "github.com/google/nftables/expr" + "github.com/stretchr/testify/require" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/test" +) + +const ( + // UNKNOWN is the default value for the firewall type for unknown firewall type + UNKNOWN = iota + // IPTABLES is the value for the iptables firewall type + IPTABLES + // NFTABLES is the value for the nftables firewall type + NFTABLES +) + +func TestNftablesManager_InsertRoutingRules(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this OS") + } + + table, err := createWorkTable() + if err != nil { + t.Fatal(err) + } + + defer deleteWorkTable() + + for _, testCase := range test.InsertRuleTestCases { + t.Run(testCase.Name, func(t *testing.T) { + manager, err := newRouter(context.TODO(), table) + require.NoError(t, err, "failed to create router") + + nftablesTestingClient := &nftables.Conn{} + + defer manager.ResetForwardRules() + + require.NoError(t, err, "shouldn't return error") + + err = manager.InsertRoutingRules(testCase.InputPair) + defer func() { + _ = manager.RemoveRoutingRules(testCase.InputPair) + }() + require.NoError(t, err, "forwarding pair should be inserted") + + sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) + destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) + testingExpression := append(sourceExp, destExp...) //nolint:gocritic + fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) + + found := 0 + for _, chain := range manager.chains { + rules, err := nftablesTestingClient.GetRules(chain.Table, chain) + require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) + for _, rule := range rules { + if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey { + require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match") + found = 1 + } + } + } + + require.Equal(t, 1, found, "should find at least 1 rule to test") + + if testCase.InputPair.Masquerade { + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) + found := 0 + for _, chain := range manager.chains { + rules, err := nftablesTestingClient.GetRules(chain.Table, chain) + require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) + for _, rule := range rules { + if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { + require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match") + found = 1 + } + } + } + require.Equal(t, 1, found, "should find at least 1 rule to test") + } + + sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source) + destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination) + testingExpression = append(sourceExp, destExp...) //nolint:gocritic + inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) + + found = 0 + for _, chain := range manager.chains { + rules, err := nftablesTestingClient.GetRules(chain.Table, chain) + require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) + for _, rule := range rules { + if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey { + require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match") + found = 1 + } + } + } + + require.Equal(t, 1, found, "should find at least 1 rule to test") + + if testCase.InputPair.Masquerade { + inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) + found := 0 + for _, chain := range manager.chains { + rules, err := nftablesTestingClient.GetRules(chain.Table, chain) + require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) + for _, rule := range rules { + if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey { + require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match") + found = 1 + } + } + } + require.Equal(t, 1, found, "should find at least 1 rule to test") + } + }) + } +} + +func TestNftablesManager_RemoveRoutingRules(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this OS") + } + + table, err := createWorkTable() + if err != nil { + t.Fatal(err) + } + + defer deleteWorkTable() + + for _, testCase := range test.RemoveRuleTestCases { + t.Run(testCase.Name, func(t *testing.T) { + manager, err := newRouter(context.TODO(), table) + require.NoError(t, err, "failed to create router") + + nftablesTestingClient := &nftables.Conn{} + + defer manager.ResetForwardRules() + + sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) + destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) + + forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic + forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) + insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ + Table: manager.workTable, + Chain: manager.chains[chainNameRouteingFw], + Exprs: forwardExp, + UserData: []byte(forwardRuleKey), + }) + + natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) + + insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{ + Table: manager.workTable, + Chain: manager.chains[chainNameRoutingNat], + Exprs: natExp, + UserData: []byte(natRuleKey), + }) + + sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source) + destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination) + + forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic + inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) + insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ + Table: manager.workTable, + Chain: manager.chains[chainNameRouteingFw], + Exprs: forwardExp, + UserData: []byte(inForwardRuleKey), + }) + + natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic + inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) + + insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{ + Table: manager.workTable, + Chain: manager.chains[chainNameRoutingNat], + Exprs: natExp, + UserData: []byte(inNatRuleKey), + }) + + err = nftablesTestingClient.Flush() + require.NoError(t, err, "shouldn't return error") + + manager.ResetForwardRules() + + err = manager.RemoveRoutingRules(testCase.InputPair) + require.NoError(t, err, "shouldn't return error") + + for _, chain := range manager.chains { + rules, err := nftablesTestingClient.GetRules(chain.Table, chain) + require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) + for _, rule := range rules { + if len(rule.UserData) > 0 { + require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist") + require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist") + require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist") + require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist") + } + } + } + }) + } +} + +// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. +func check() int { + nf := nftables.Conn{} + if _, err := nf.ListChains(); err == nil { + return NFTABLES + } + + ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + if err != nil { + return UNKNOWN + } + if isIptablesClientAvailable(ip) { + return IPTABLES + } + + return UNKNOWN +} + +func isIptablesClientAvailable(client *iptables.IPTables) bool { + _, err := client.ListChains("filter") + return err == nil +} + +func createWorkTable() (*nftables.Table, error) { + sConn, err := nftables.New(nftables.AsLasting()) + if err != nil { + return nil, err + } + + tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4) + if err != nil { + return nil, err + } + + for _, t := range tables { + if t.Name == tableName { + sConn.DelTable(t) + } + } + + table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}) + err = sConn.Flush() + + return table, err +} + +func deleteWorkTable() { + sConn, err := nftables.New(nftables.AsLasting()) + if err != nil { + return + } + + tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4) + if err != nil { + return + } + + for _, t := range tables { + if t.Name == tableName { + sConn.DelTable(t) + } + } +} diff --git a/client/firewall/nftables/rule_linux.go b/client/firewall/nftables/rule_linux.go index 98d1147cd..678c10b44 100644 --- a/client/firewall/nftables/rule_linux.go +++ b/client/firewall/nftables/rule_linux.go @@ -1,6 +1,8 @@ package nftables import ( + "net" + "github.com/google/nftables" ) @@ -8,9 +10,8 @@ import ( type Rule struct { nftRule *nftables.Rule nftSet *nftables.Set - - ruleID string - ip []byte + ruleID string + ip net.IP } // GetRuleID returns the rule id diff --git a/client/firewall/nftables/ruleset_linux.go b/client/firewall/nftables/ruleset_linux.go deleted file mode 100644 index 536a5ee18..000000000 --- a/client/firewall/nftables/ruleset_linux.go +++ /dev/null @@ -1,115 +0,0 @@ -package nftables - -import ( - "bytes" - "fmt" - - "github.com/google/nftables" - "github.com/rs/xid" -) - -// nftRuleset links native firewall rule and ipset to ACL generated rules -type nftRuleset struct { - nftRule *nftables.Rule - nftSet *nftables.Set - issuedRules map[string]*Rule - rulesetID string -} - -type rulesetManager struct { - rulesets map[string]*nftRuleset - - nftSetName2rulesetID map[string]string - issuedRuleID2rulesetID map[string]string -} - -func newRuleManager() *rulesetManager { - return &rulesetManager{ - rulesets: map[string]*nftRuleset{}, - - nftSetName2rulesetID: map[string]string{}, - issuedRuleID2rulesetID: map[string]string{}, - } -} - -func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) { - ruleset, ok := r.rulesets[rulesetID] - return ruleset, ok -} - -func (r *rulesetManager) createRuleset( - rulesetID string, - nftRule *nftables.Rule, - nftSet *nftables.Set, -) *nftRuleset { - ruleset := nftRuleset{ - rulesetID: rulesetID, - nftRule: nftRule, - nftSet: nftSet, - issuedRules: map[string]*Rule{}, - } - r.rulesets[ruleset.rulesetID] = &ruleset - if nftSet != nil { - r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID - } - return &ruleset -} - -func (r *rulesetManager) addRule( - ruleset *nftRuleset, - ip []byte, -) (*Rule, error) { - if _, ok := r.rulesets[ruleset.rulesetID]; !ok { - return nil, fmt.Errorf("ruleset not found") - } - - rule := Rule{ - nftRule: ruleset.nftRule, - nftSet: ruleset.nftSet, - ruleID: xid.New().String(), - ip: ip, - } - - ruleset.issuedRules[rule.ruleID] = &rule - r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID - - return &rule, nil -} - -// deleteRule from ruleset and returns true if contains other rules -func (r *rulesetManager) deleteRule(rule *Rule) bool { - rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID] - if !ok { - return false - } - - ruleset := r.rulesets[rulesetID] - if ruleset.nftRule == nil { - return false - } - delete(r.issuedRuleID2rulesetID, rule.ruleID) - delete(ruleset.issuedRules, rule.ruleID) - - if len(ruleset.issuedRules) == 0 { - delete(r.rulesets, ruleset.rulesetID) - if rule.nftSet != nil { - delete(r.nftSetName2rulesetID, rule.nftSet.Name) - } - return false - } - return true -} - -// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number -// -// This is important to do, because after we add rule to the nftables we can't update it until -// we set correct handle value to it. -func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error { - split := bytes.Split(nftRule.UserData, []byte(" ")) - ruleset, ok := r.rulesets[string(split[0])] - if !ok { - return fmt.Errorf("ruleset not found") - } - *ruleset.nftRule = *nftRule - return nil -} diff --git a/client/firewall/nftables/ruleset_linux_test.go b/client/firewall/nftables/ruleset_linux_test.go deleted file mode 100644 index 74b37d8f8..000000000 --- a/client/firewall/nftables/ruleset_linux_test.go +++ /dev/null @@ -1,122 +0,0 @@ -package nftables - -import ( - "testing" - - "github.com/google/nftables" - "github.com/stretchr/testify/require" -) - -func TestRulesetManager_createRuleset(t *testing.T) { - // Create a ruleset manager. - rulesetManager := newRuleManager() - - // Create a ruleset. - rulesetID := "ruleset-1" - nftRule := nftables.Rule{ - UserData: []byte(rulesetID), - } - ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil) - require.NotNil(t, ruleset, "createRuleset() failed") - require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect") - require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect") -} - -func TestRulesetManager_addRule(t *testing.T) { - // Create a ruleset manager. - rulesetManager := newRuleManager() - - // Create a ruleset. - rulesetID := "ruleset-1" - nftRule := nftables.Rule{} - ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil) - - // Add a rule to the ruleset. - ip := []byte("192.168.1.1") - rule, err := rulesetManager.addRule(ruleset, ip) - require.NoError(t, err, "addRule() failed") - require.NotNil(t, rule, "rule should not be nil") - require.NotEqual(t, rule.ruleID, "ruleID is empty") - require.EqualValues(t, rule.ip, ip, "ip is incorrect") - require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset") - require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager") - - ruleset2 := &nftRuleset{ - rulesetID: "ruleset-2", - } - _, err = rulesetManager.addRule(ruleset2, ip) - require.Error(t, err, "addRule() should have failed") -} - -func TestRulesetManager_deleteRule(t *testing.T) { - // Create a ruleset manager. - rulesetManager := newRuleManager() - - // Create a ruleset. - rulesetID := "ruleset-1" - nftRule := nftables.Rule{} - ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil) - - // Add a rule to the ruleset. - ip := []byte("192.168.1.1") - rule, err := rulesetManager.addRule(ruleset, ip) - require.NoError(t, err, "addRule() failed") - require.NotNil(t, rule, "rule should not be nil") - - ip2 := []byte("192.168.1.1") - rule2, err := rulesetManager.addRule(ruleset, ip2) - require.NoError(t, err, "addRule() failed") - require.NotNil(t, rule2, "rule should not be nil") - - hasNext := rulesetManager.deleteRule(rule) - require.True(t, hasNext, "deleteRule() should have returned true") - - // Check that the rule is no longer in the manager. - require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted") - - hasNext = rulesetManager.deleteRule(rule2) - require.False(t, hasNext, "deleteRule() should have returned false") -} - -func TestRulesetManager_setNftRuleHandle(t *testing.T) { - // Create a ruleset manager. - rulesetManager := newRuleManager() - // Create a ruleset. - rulesetID := "ruleset-1" - nftRule := nftables.Rule{} - ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil) - // Add a rule to the ruleset. - ip := []byte("192.168.0.1") - - rule, err := rulesetManager.addRule(ruleset, ip) - require.NoError(t, err, "addRule() failed") - require.NotNil(t, rule, "rule should not be nil") - - nftRuleCopy := nftRule - nftRuleCopy.Handle = 2 - nftRuleCopy.UserData = []byte(rulesetID) - err = rulesetManager.setNftRuleHandle(&nftRuleCopy) - require.NoError(t, err, "setNftRuleHandle() failed") - // check correct work with references - require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect") -} - -func TestRulesetManager_getRuleset(t *testing.T) { - // Create a ruleset manager. - rulesetManager := newRuleManager() - // Create a ruleset. - rulesetID := "ruleset-1" - nftRule := nftables.Rule{} - nftSet := nftables.Set{ - ID: 2, - } - ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet) - require.NotNil(t, ruleset, "createRuleset() failed") - - find, ok := rulesetManager.getRuleset(rulesetID) - require.True(t, ok, "getRuleset() failed") - require.Equal(t, ruleset, find, "getRulesetBySetID() failed") - - _, ok = rulesetManager.getRuleset("does-not-exist") - require.False(t, ok, "getRuleset() failed") -} diff --git a/client/firewall/test/cases_linux.go b/client/firewall/test/cases_linux.go new file mode 100644 index 000000000..432d113dd --- /dev/null +++ b/client/firewall/test/cases_linux.go @@ -0,0 +1,47 @@ +//go:build !android + +package test + +import firewall "github.com/netbirdio/netbird/client/firewall/manager" + +var ( + InsertRuleTestCases = []struct { + Name string + InputPair firewall.RouterPair + }{ + { + Name: "Insert Forwarding IPV4 Rule", + InputPair: firewall.RouterPair{ + ID: "zxa", + Source: "100.100.100.1/32", + Destination: "100.100.200.0/24", + Masquerade: false, + }, + }, + { + Name: "Insert Forwarding And Nat IPV4 Rules", + InputPair: firewall.RouterPair{ + ID: "zxa", + Source: "100.100.100.1/32", + Destination: "100.100.200.0/24", + Masquerade: true, + }, + }, + } + + RemoveRuleTestCases = []struct { + Name string + InputPair firewall.RouterPair + IpVersion string + }{ + { + Name: "Remove Forwarding And Nat IPV4 Rules", + InputPair: firewall.RouterPair{ + ID: "zxa", + Source: "100.100.100.1/32", + Destination: "100.100.200.0/24", + Masquerade: true, + }, + }, + } +) diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index ccfef1861..2275dad39 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -1,4 +1,4 @@ -//go:build !windows && !linux +//go:build !windows package uspfilter @@ -10,10 +10,16 @@ func (m *Manager) Reset() error { m.outgoingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet) + if m.nativeFirewall != nil { + return m.nativeFirewall.Reset() + } return nil } // AllowNetbird allows netbird interface traffic func (m *Manager) AllowNetbird() error { + if m.nativeFirewall != nil { + return m.nativeFirewall.AllowNetbird() + } return nil } diff --git a/client/firewall/uspfilter/allow_netbird_linux.go b/client/firewall/uspfilter/allow_netbird_linux.go deleted file mode 100644 index 5df48c756..000000000 --- a/client/firewall/uspfilter/allow_netbird_linux.go +++ /dev/null @@ -1,21 +0,0 @@ -package uspfilter - -// AllowNetbird allows netbird interface traffic -func (m *Manager) AllowNetbird() error { - return nil -} - -// Reset firewall to the default state -func (m *Manager) Reset() error { - m.mutex.Lock() - defer m.mutex.Unlock() - - m.outgoingRules = make(map[string]RuleSet) - m.incomingRules = make(map[string]RuleSet) - - if m.resetHook != nil { - return m.resetHook() - } - - return nil -} diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index 40872f67d..5c1daccaf 100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -5,7 +5,7 @@ import ( "github.com/google/gopacket" - fw "github.com/netbirdio/netbird/client/firewall" + firewall "github.com/netbirdio/netbird/client/firewall/manager" ) // Rule to handle management of rules @@ -15,7 +15,7 @@ type Rule struct { ipLayer gopacket.LayerType matchByIP bool protoLayer gopacket.LayerType - direction fw.RuleDirection + direction firewall.RuleDirection sPort uint16 dPort uint16 drop bool diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 7119e791c..427a73825 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -10,12 +10,16 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" - fw "github.com/netbirdio/netbird/client/firewall" + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/iface" ) const layerTypeAll = 0 +var ( + errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") +) + // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { SetFilter(iface.PacketFilter) error @@ -27,12 +31,12 @@ type RuleSet map[string]Rule // Manager userspace firewall manager type Manager struct { - outgoingRules map[string]RuleSet - incomingRules map[string]RuleSet - wgNetwork *net.IPNet - decoders sync.Pool - wgIface IFaceMapper - resetHook func() error + outgoingRules map[string]RuleSet + incomingRules map[string]RuleSet + wgNetwork *net.IPNet + decoders sync.Pool + wgIface IFaceMapper + nativeFirewall firewall.Manager mutex sync.RWMutex } @@ -52,6 +56,20 @@ type decoder struct { // Create userspace firewall manager constructor func Create(iface IFaceMapper) (*Manager, error) { + return create(iface) +} + +func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) { + mgr, err := create(iface) + if err != nil { + return nil, err + } + + mgr.nativeFirewall = nativeFirewall + return mgr, nil +} + +func create(iface IFaceMapper) (*Manager, error) { m := &Manager{ decoders: sync.Pool{ New: func() any { @@ -77,27 +95,50 @@ func Create(iface IFaceMapper) (*Manager, error) { return m, nil } +func (m *Manager) IsServerRouteSupported() bool { + if m.nativeFirewall == nil { + return false + } else { + return true + } +} + +func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { + if m.nativeFirewall == nil { + return errRouteNotSupported + } + return m.nativeFirewall.InsertRoutingRules(pair) +} + +// RemoveRoutingRules removes a routing firewall rule +func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { + if m.nativeFirewall == nil { + return errRouteNotSupported + } + return m.nativeFirewall.RemoveRoutingRules(pair) +} + // AddFiltering rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule func (m *Manager) AddFiltering( ip net.IP, - proto fw.Protocol, - sPort *fw.Port, - dPort *fw.Port, - direction fw.RuleDirection, - action fw.Action, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + direction firewall.RuleDirection, + action firewall.Action, ipsetName string, comment string, -) (fw.Rule, error) { +) ([]firewall.Rule, error) { r := Rule{ id: uuid.New().String(), ip: ip, ipLayer: layers.LayerTypeIPv6, matchByIP: true, direction: direction, - drop: action == fw.ActionDrop, + drop: action == firewall.ActionDrop, comment: comment, } if ipNormalized := ip.To4(); ipNormalized != nil { @@ -118,21 +159,21 @@ func (m *Manager) AddFiltering( } switch proto { - case fw.ProtocolTCP: + case firewall.ProtocolTCP: r.protoLayer = layers.LayerTypeTCP - case fw.ProtocolUDP: + case firewall.ProtocolUDP: r.protoLayer = layers.LayerTypeUDP - case fw.ProtocolICMP: + case firewall.ProtocolICMP: r.protoLayer = layers.LayerTypeICMPv4 if r.ipLayer == layers.LayerTypeIPv6 { r.protoLayer = layers.LayerTypeICMPv6 } - case fw.ProtocolALL: + case firewall.ProtocolALL: r.protoLayer = layerTypeAll } m.mutex.Lock() - if direction == fw.RuleDirectionIN { + if direction == firewall.RuleDirectionIN { if _, ok := m.incomingRules[r.ip.String()]; !ok { m.incomingRules[r.ip.String()] = make(RuleSet) } @@ -144,12 +185,11 @@ func (m *Manager) AddFiltering( m.outgoingRules[r.ip.String()][r.id] = r } m.mutex.Unlock() - - return &r, nil + return []firewall.Rule{&r}, nil } // DeleteRule from the firewall by rule definition -func (m *Manager) DeleteRule(rule fw.Rule) error { +func (m *Manager) DeleteRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -158,7 +198,7 @@ func (m *Manager) DeleteRule(rule fw.Rule) error { return fmt.Errorf("delete rule: invalid rule type: %T", rule) } - if r.direction == fw.RuleDirectionIN { + if r.direction == firewall.RuleDirectionIN { _, ok := m.incomingRules[r.ip.String()][r.id] if !ok { return fmt.Errorf("delete rule: no rule with such id: %v", r.id) @@ -322,7 +362,7 @@ func (m *Manager) AddUDPPacketHook( protoLayer: layers.LayerTypeUDP, dPort: dPort, ipLayer: layers.LayerTypeIPv6, - direction: fw.RuleDirectionOUT, + direction: firewall.RuleDirectionOUT, comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort), udpHook: hook, } @@ -333,7 +373,7 @@ func (m *Manager) AddUDPPacketHook( m.mutex.Lock() if in { - r.direction = fw.RuleDirectionIN + r.direction = firewall.RuleDirectionIN if _, ok := m.incomingRules[r.ip.String()]; !ok { m.incomingRules[r.ip.String()] = make(map[string]Rule) } @@ -370,8 +410,3 @@ func (m *Manager) RemovePacketHook(hookID string) error { } return fmt.Errorf("hook with given id not found") } - -// SetResetHook which will be executed in the end of Reset method -func (m *Manager) SetResetHook(hook func() error) { - m.resetHook = hook -} diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 6b3d334a8..514a90539 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -10,7 +10,7 @@ import ( "github.com/google/gopacket/layers" "github.com/stretchr/testify/require" - fw "github.com/netbirdio/netbird/client/firewall" + fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/iface" ) @@ -125,24 +125,32 @@ func TestManagerDeleteRule(t *testing.T) { return } - err = m.DeleteRule(rule) - if err != nil { - t.Errorf("failed to delete rule: %v", err) - return + for _, r := range rule { + err = m.DeleteRule(r) + if err != nil { + t.Errorf("failed to delete rule: %v", err) + return + } } - if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok { - t.Errorf("rule2 is not in the incomingRules") + for _, r := range rule2 { + if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok { + t.Errorf("rule2 is not in the incomingRules") + } } - err = m.DeleteRule(rule2) - if err != nil { - t.Errorf("failed to delete rule: %v", err) - return + for _, r := range rule2 { + err = m.DeleteRule(r) + if err != nil { + t.Errorf("failed to delete rule: %v", err) + return + } } - if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok { - t.Errorf("rule2 is not in the incomingRules") + for _, r := range rule2 { + if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok { + t.Errorf("rule2 is not in the incomingRules") + } } } diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index feaaa7b8b..fd2c2c875 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -11,42 +11,27 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/firewall" + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/ssh" - "github.com/netbirdio/netbird/iface" mgmProto "github.com/netbirdio/netbird/management/proto" ) -// IFaceMapper defines subset methods of interface required for manager -type IFaceMapper interface { - Name() string - Address() iface.WGAddress - IsUserspaceBind() bool - SetFilter(iface.PacketFilter) error -} - // Manager is a ACL rules manager type Manager interface { ApplyFiltering(networkMap *mgmProto.NetworkMap) - Stop() } // DefaultManager uses firewall manager to handle type DefaultManager struct { - manager firewall.Manager + firewall firewall.Manager ipsetCounter int rulesPairs map[string][]firewall.Rule mutex sync.Mutex } -type ipsetInfo struct { - name string - ipCount int -} - -func newDefaultManager(fm firewall.Manager) *DefaultManager { +func NewDefaultManager(fm firewall.Manager) *DefaultManager { return &DefaultManager{ - manager: fm, + firewall: fm, rulesPairs: make(map[string][]firewall.Rule), } } @@ -69,13 +54,13 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { time.Since(start), total) }() - if d.manager == nil { + if d.firewall == nil { log.Debug("firewall manager is not supported, skipping firewall rules") return } defer func() { - if err := d.manager.Flush(); err != nil { + if err := d.firewall.Flush(); err != nil { log.Error("failed to flush firewall rules: ", err) } }() @@ -125,57 +110,35 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { ) } - applyFailed := false newRulePairs := make(map[string][]firewall.Rule) - ipsetByRuleSelectors := make(map[string]*ipsetInfo) - - // calculate which IP's can be grouped in by which ipset - // to do that we use rule selector (which is just rule properties without IP's) - for _, r := range rules { - selector := d.getRuleGroupingSelector(r) - ipset, ok := ipsetByRuleSelectors[selector] - if !ok { - ipset = &ipsetInfo{} - } - - ipset.ipCount++ - ipsetByRuleSelectors[selector] = ipset - } + ipsetByRuleSelectors := make(map[string]string) for _, r := range rules { // if this rule is member of rule selection with more than DefaultIPsCountForSet // it's IP address can be used in the ipset for firewall manager which supports it - ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)] - if ipset.name == "" { + selector := d.getRuleGroupingSelector(r) + ipsetName, ok := ipsetByRuleSelectors[selector] + if !ok { d.ipsetCounter++ - ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter) + ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter) + ipsetByRuleSelectors[selector] = ipsetName } - ipsetName := ipset.name pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName) if err != nil { log.Errorf("failed to apply firewall rule: %+v, %v", r, err) - applyFailed = true + d.rollBack(newRulePairs) break } - newRulePairs[pairID] = rulePair - } - if applyFailed { - log.Error("failed to apply firewall rules, rollback ACL to previous state") - for _, rules := range newRulePairs { - for _, rule := range rules { - if err := d.manager.DeleteRule(rule); err != nil { - log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err) - continue - } - } + if len(rules) > 0 { + d.rulesPairs[pairID] = rulePair + newRulePairs[pairID] = rulePair } - return } for pairID, rules := range d.rulesPairs { if _, ok := newRulePairs[pairID]; !ok { for _, rule := range rules { - if err := d.manager.DeleteRule(rule); err != nil { + if err := d.firewall.DeleteRule(rule); err != nil { log.Errorf("failed to delete firewall rule: %v", err) continue } @@ -186,16 +149,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { d.rulesPairs = newRulePairs } -// Stop ACL controller and clear firewall state -func (d *DefaultManager) Stop() { - d.mutex.Lock() - defer d.mutex.Unlock() - - if err := d.manager.Reset(); err != nil { - log.WithError(err).Error("reset firewall state") - } -} - func (d *DefaultManager) protoRuleToFirewallRule( r *mgmProto.FirewallRule, ipsetName string, @@ -205,14 +158,14 @@ func (d *DefaultManager) protoRuleToFirewallRule( return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule") } - protocol := convertToFirewallProtocol(r.Protocol) - if protocol == firewall.ProtocolUnknown { - return "", nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol) + protocol, err := convertToFirewallProtocol(r.Protocol) + if err != nil { + return "", nil, fmt.Errorf("skipping firewall rule: %s", err) } - action := convertFirewallAction(r.Action) - if action == firewall.ActionUnknown { - return "", nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action) + action, err := convertFirewallAction(r.Action) + if err != nil { + return "", nil, fmt.Errorf("skipping firewall rule: %s", err) } var port *firewall.Port @@ -232,7 +185,6 @@ func (d *DefaultManager) protoRuleToFirewallRule( } var rules []firewall.Rule - var err error switch r.Direction { case mgmProto.FirewallRule_IN: rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") @@ -246,7 +198,6 @@ func (d *DefaultManager) protoRuleToFirewallRule( return "", nil, err } - d.rulesPairs[ruleID] = rules return ruleID, rules, nil } @@ -259,24 +210,24 @@ func (d *DefaultManager) addInRules( comment string, ) ([]firewall.Rule, error) { var rules []firewall.Rule - rule, err := d.manager.AddFiltering( + rule, err := d.firewall.AddFiltering( ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) } - rules = append(rules, rule) + rules = append(rules, rule...) if shouldSkipInvertedRule(protocol, port) { return rules, nil } - rule, err = d.manager.AddFiltering( + rule, err = d.firewall.AddFiltering( ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) } - return append(rules, rule), nil + return append(rules, rule...), nil } func (d *DefaultManager) addOutRules( @@ -288,24 +239,24 @@ func (d *DefaultManager) addOutRules( comment string, ) ([]firewall.Rule, error) { var rules []firewall.Rule - rule, err := d.manager.AddFiltering( + rule, err := d.firewall.AddFiltering( ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) } - rules = append(rules, rule) + rules = append(rules, rule...) if shouldSkipInvertedRule(protocol, port) { return rules, nil } - rule, err = d.manager.AddFiltering( + rule, err = d.firewall.AddFiltering( ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) } - return append(rules, rule), nil + return append(rules, rule...), nil } // getRuleID() returns unique ID for the rule based on its parameters. @@ -461,18 +412,29 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port) } -func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol { +func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) { + log.Debugf("rollback ACL to previous state") + for _, rules := range newRulePairs { + for _, rule := range rules { + if err := d.firewall.DeleteRule(rule); err != nil { + log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err) + } + } + } +} + +func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) { switch protocol { case mgmProto.FirewallRule_TCP: - return firewall.ProtocolTCP + return firewall.ProtocolTCP, nil case mgmProto.FirewallRule_UDP: - return firewall.ProtocolUDP + return firewall.ProtocolUDP, nil case mgmProto.FirewallRule_ICMP: - return firewall.ProtocolICMP + return firewall.ProtocolICMP, nil case mgmProto.FirewallRule_ALL: - return firewall.ProtocolALL + return firewall.ProtocolALL, nil default: - return firewall.ProtocolUnknown + return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String()) } } @@ -480,13 +442,13 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil } -func convertFirewallAction(action mgmProto.FirewallRuleAction) firewall.Action { +func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) { switch action { case mgmProto.FirewallRule_ACCEPT: - return firewall.ActionAccept + return firewall.ActionAccept, nil case mgmProto.FirewallRule_DROP: - return firewall.ActionDrop + return firewall.ActionDrop, nil default: - return firewall.ActionUnknown + return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action) } } diff --git a/client/internal/acl/manager_create.go b/client/internal/acl/manager_create.go deleted file mode 100644 index 66185749b..000000000 --- a/client/internal/acl/manager_create.go +++ /dev/null @@ -1,28 +0,0 @@ -//go:build !linux || android - -package acl - -import ( - "fmt" - "runtime" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/firewall/uspfilter" -) - -// Create creates a firewall manager instance -func Create(iface IFaceMapper) (manager *DefaultManager, err error) { - if iface.IsUserspaceBind() { - // use userspace packet filtering firewall - fm, err := uspfilter.Create(iface) - if err != nil { - return nil, err - } - if err := fm.AllowNetbird(); err != nil { - log.Warnf("failed to allow netbird interface traffic: %v", err) - } - return newDefaultManager(fm), nil - } - return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) -} diff --git a/client/internal/acl/manager_create_linux.go b/client/internal/acl/manager_create_linux.go deleted file mode 100644 index 05b042351..000000000 --- a/client/internal/acl/manager_create_linux.go +++ /dev/null @@ -1,77 +0,0 @@ -//go:build !android - -package acl - -import ( - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/firewall" - "github.com/netbirdio/netbird/client/firewall/iptables" - "github.com/netbirdio/netbird/client/firewall/nftables" - "github.com/netbirdio/netbird/client/firewall/uspfilter" - "github.com/netbirdio/netbird/client/internal/checkfw" -) - -// Create creates a firewall manager instance for the Linux -func Create(iface IFaceMapper) (*DefaultManager, error) { - // on the linux system we try to user nftables or iptables - // in any case, because we need to allow netbird interface traffic - // so we use AllowNetbird traffic from these firewall managers - // for the userspace packet filtering firewall - var fm firewall.Manager - var err error - - checkResult := checkfw.Check() - switch checkResult { - case checkfw.IPTABLES, checkfw.IPTABLESWITHV6: - log.Debug("creating an iptables firewall manager for access control") - ipv6Supported := checkResult == checkfw.IPTABLESWITHV6 - if fm, err = iptables.Create(iface, ipv6Supported); err != nil { - log.Infof("failed to create iptables manager for access control: %s", err) - } - case checkfw.NFTABLES: - log.Debug("creating an nftables firewall manager for access control") - if fm, err = nftables.Create(iface); err != nil { - log.Debugf("failed to create nftables manager for access control: %s", err) - } - } - - var resetHookForUserspace func() error - if fm != nil && err == nil { - // err shadowing is used here, to ignore this error - if err := fm.AllowNetbird(); err != nil { - log.Errorf("failed to allow netbird interface traffic: %v", err) - } - resetHookForUserspace = fm.Reset - } - - if iface.IsUserspaceBind() { - // use userspace packet filtering firewall - usfm, err := uspfilter.Create(iface) - if err != nil { - log.Debugf("failed to create userspace filtering firewall: %s", err) - return nil, err - } - - // set kernel space firewall Reset as hook for userspace firewall - // manager Reset method, to clean up - if resetHookForUserspace != nil { - usfm.SetResetHook(resetHookForUserspace) - } - - // to be consistent for any future extensions. - // ignore this error - if err := usfm.AllowNetbird(); err != nil { - log.Errorf("failed to allow netbird interface traffic: %v", err) - } - fm = usfm - } - - if fm == nil || err != nil { - log.Errorf("failed to create firewall manager: %s", err) - // no firewall manager found or initialized correctly - return nil, err - } - - return newDefaultManager(fm), nil -} diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index d55a1cad6..5e3db0a24 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -1,11 +1,14 @@ package acl import ( + "context" "net" "testing" "github.com/golang/mock/gomock" + "github.com/netbirdio/netbird/client/firewall" + "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/iface" mgmProto "github.com/netbirdio/netbird/management/proto" @@ -49,12 +52,15 @@ func TestDefaultManager(t *testing.T) { }).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - acl, err := Create(ifaceMock) + fw, err := firewall.NewFirewall(context.Background(), ifaceMock) if err != nil { - t.Errorf("create ACL manager: %v", err) + t.Errorf("create firewall: %v", err) return } - defer acl.Stop() + defer func(fw manager.Manager) { + _ = fw.Reset() + }(fw) + acl := NewDefaultManager(fw) t.Run("apply firewall rules", func(t *testing.T) { acl.ApplyFiltering(networkMap) @@ -339,12 +345,15 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { }).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - acl, err := Create(ifaceMock) + fw, err := firewall.NewFirewall(context.Background(), ifaceMock) if err != nil { - t.Errorf("create ACL manager: %v", err) + t.Errorf("create firewall: %v", err) return } - defer acl.Stop() + defer func(fw manager.Manager) { + _ = fw.Reset() + }(fw) + acl := NewDefaultManager(fw) acl.ApplyFiltering(networkMap) diff --git a/client/internal/checkfw/check.go b/client/internal/checkfw/check.go deleted file mode 100644 index edfd8a5b3..000000000 --- a/client/internal/checkfw/check.go +++ /dev/null @@ -1,3 +0,0 @@ -//go:build !linux || android - -package checkfw diff --git a/client/internal/checkfw/check_linux.go b/client/internal/checkfw/check_linux.go deleted file mode 100644 index 552d5698c..000000000 --- a/client/internal/checkfw/check_linux.go +++ /dev/null @@ -1,56 +0,0 @@ -//go:build !android - -package checkfw - -import ( - "os" - - "github.com/coreos/go-iptables/iptables" - "github.com/google/nftables" -) - -const ( - // UNKNOWN is the default value for the firewall type for unknown firewall type - UNKNOWN FWType = iota - // IPTABLES is the value for the iptables firewall type - IPTABLES - // IPTABLESWITHV6 is the value for the iptables firewall type with ipv6 - IPTABLESWITHV6 - // NFTABLES is the value for the nftables firewall type - NFTABLES -) - -// SKIP_NFTABLES_ENV is the environment variable to skip nftables check -const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" - -// FWType is the type for the firewall type -type FWType int - -// Check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. -func Check() FWType { - nf := nftables.Conn{} - if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" { - return NFTABLES - } - - ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) - if err == nil { - if isIptablesClientAvailable(ip) { - ipSupport := IPTABLES - ipv6, ip6Err := iptables.NewWithProtocol(iptables.ProtocolIPv6) - if ip6Err == nil { - if isIptablesClientAvailable(ipv6) { - ipSupport = IPTABLESWITHV6 - } - } - return ipSupport - } - } - - return UNKNOWN -} - -func isIptablesClientAvailable(client *iptables.IPTables) bool { - _, err := client.ListChains("filter") - return err == nil -} diff --git a/client/internal/dns/host_linux.go b/client/internal/dns/host_linux.go index 7838c988f..a6557f121 100644 --- a/client/internal/dns/host_linux.go +++ b/client/internal/dns/host_linux.go @@ -25,13 +25,30 @@ const ( type osManagerType int +func (t osManagerType) String() string { + switch t { + case netbirdManager: + return "netbird" + case fileManager: + return "file" + case networkManager: + return "networkManager" + case systemdManager: + return "systemd" + case resolvConfManager: + return "resolvconf" + default: + return "unknown" + } +} + func newHostManager(wgInterface WGIface) (hostManager, error) { osManager, err := getOSDNSManagerType() if err != nil { return nil, err } - log.Debugf("discovered mode is: %d", osManager) + log.Debugf("discovered mode is: %s", osManager) switch osManager { case networkManager: return newNetworkManagerDbusConfigurator(wgInterface) @@ -65,7 +82,6 @@ func getOSDNSManagerType() (osManagerType, error) { return netbirdManager, nil } if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { - log.Debugf("is nm running on supported v? %t", isNetworkManagerSupportedVersion()) return networkManager, nil } if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) { diff --git a/client/internal/engine.go b/client/internal/engine.go index 4d461b746..c525601b4 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -17,6 +17,8 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/firewall" + "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" @@ -115,6 +117,7 @@ type Engine struct { statusRecorder *peer.Status + firewall manager.Manager routeManager routemanager.Manager acl acl.Manager @@ -231,6 +234,19 @@ func (e *Engine) Start() error { return err } + e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface) + if err != nil { + log.Errorf("failed creating firewall manager: %s", err) + } + + if e.firewall != nil && e.firewall.IsServerRouteSupported() { + err = e.routeManager.EnableServerRouter(e.firewall) + if err != nil { + e.close() + return err + } + } + err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort) if err != nil { log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIFaceName, err.Error()) @@ -258,10 +274,8 @@ func (e *Engine) Start() error { e.udpMux = mux } - if acl, err := acl.Create(e.wgInterface); err != nil { - log.Errorf("failed to create ACL manager, policy will not work: %s", err.Error()) - } else { - e.acl = acl + if e.firewall != nil { + e.acl = acl.NewDefaultManager(e.firewall) } err = e.dnsServer.Initialize() @@ -1044,8 +1058,11 @@ func (e *Engine) close() { e.dnsServer.Stop() } - if e.acl != nil { - e.acl.Stop() + if e.firewall != nil { + err := e.firewall.Reset() + if err != nil { + log.Warnf("failed to reset firewall: %s", err) + } } } diff --git a/client/internal/routemanager/common_linux_test.go b/client/internal/routemanager/common_linux_test.go deleted file mode 100644 index d27f532cd..000000000 --- a/client/internal/routemanager/common_linux_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package routemanager - -var insertRuleTestCases = []struct { - name string - inputPair routerPair - ipVersion string -}{ - { - name: "Insert Forwarding IPV4 Rule", - inputPair: routerPair{ - ID: "zxa", - source: "100.100.100.1/32", - destination: "100.100.200.0/24", - masquerade: false, - }, - ipVersion: ipv4, - }, - { - name: "Insert Forwarding And Nat IPV4 Rules", - inputPair: routerPair{ - ID: "zxa", - source: "100.100.100.1/32", - destination: "100.100.200.0/24", - masquerade: true, - }, - ipVersion: ipv4, - }, - { - name: "Insert Forwarding IPV6 Rule", - inputPair: routerPair{ - ID: "zxa", - source: "fc00::1/128", - destination: "fc12::/64", - masquerade: false, - }, - ipVersion: ipv6, - }, - { - name: "Insert Forwarding And Nat IPV6 Rules", - inputPair: routerPair{ - ID: "zxa", - source: "fc00::1/128", - destination: "fc12::/64", - masquerade: true, - }, - ipVersion: ipv6, - }, -} - -var removeRuleTestCases = []struct { - name string - inputPair routerPair - ipVersion string -}{ - { - name: "Remove Forwarding And Nat IPV4 Rules", - inputPair: routerPair{ - ID: "zxa", - source: "100.100.100.1/32", - destination: "100.100.200.0/24", - masquerade: true, - }, - ipVersion: ipv4, - }, - { - name: "Remove Forwarding And Nat IPV6 Rules", - inputPair: routerPair{ - ID: "zxa", - source: "fc00::1/128", - destination: "fc12::/64", - masquerade: true, - }, - ipVersion: ipv6, - }, -} diff --git a/client/internal/routemanager/firewall.go b/client/internal/routemanager/firewall.go deleted file mode 100644 index fc6ff58f1..000000000 --- a/client/internal/routemanager/firewall.go +++ /dev/null @@ -1,12 +0,0 @@ -package routemanager - -type firewallManager interface { - // RestoreOrCreateContainers restores or creates a firewall container set of rules, tables and default rules - RestoreOrCreateContainers() error - // InsertRoutingRules inserts a routing firewall rule - InsertRoutingRules(pair routerPair) error - // RemoveRoutingRules removes a routing firewall rule - RemoveRoutingRules(pair routerPair) error - // CleanRoutingRules cleans a firewall set of containers - CleanRoutingRules() -} diff --git a/client/internal/routemanager/firewall_linux.go b/client/internal/routemanager/firewall_linux.go deleted file mode 100644 index 50d451a88..000000000 --- a/client/internal/routemanager/firewall_linux.go +++ /dev/null @@ -1,55 +0,0 @@ -//go:build !android - -package routemanager - -import ( - "context" - "fmt" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/checkfw" -) - -const ( - ipv6Forwarding = "netbird-rt-ipv6-forwarding" - ipv4Forwarding = "netbird-rt-ipv4-forwarding" - ipv6Nat = "netbird-rt-ipv6-nat" - ipv4Nat = "netbird-rt-ipv4-nat" - natFormat = "netbird-nat-%s" - forwardingFormat = "netbird-fwd-%s" - inNatFormat = "netbird-nat-in-%s" - inForwardingFormat = "netbird-fwd-in-%s" - ipv6 = "ipv6" - ipv4 = "ipv4" -) - -func genKey(format string, input string) string { - return fmt.Sprintf(format, input) -} - -// newFirewall if supported, returns an iptables manager, otherwise returns a nftables manager -func newFirewall(parentCTX context.Context) (firewallManager, error) { - checkResult := checkfw.Check() - switch checkResult { - case checkfw.IPTABLES, checkfw.IPTABLESWITHV6: - log.Debug("creating an iptables firewall manager for route rules") - ipv6Supported := checkResult == checkfw.IPTABLESWITHV6 - return newIptablesManager(parentCTX, ipv6Supported) - case checkfw.NFTABLES: - log.Info("creating an nftables firewall manager for route rules") - return newNFTablesManager(parentCTX), nil - } - - return nil, fmt.Errorf("couldn't initialize nftables or iptables clients. Using a dummy firewall manager for route rules") -} - -func getInPair(pair routerPair) routerPair { - return routerPair{ - ID: pair.ID, - // invert source/destination - source: pair.destination, - destination: pair.source, - masquerade: pair.masquerade, - } -} diff --git a/client/internal/routemanager/firewall_nonlinux.go b/client/internal/routemanager/firewall_nonlinux.go deleted file mode 100644 index ae0627048..000000000 --- a/client/internal/routemanager/firewall_nonlinux.go +++ /dev/null @@ -1,15 +0,0 @@ -//go:build !linux -// +build !linux - -package routemanager - -import ( - "context" - "fmt" - "runtime" -) - -// newFirewall returns a nil manager -func newFirewall(context.Context) (firewallManager, error) { - return nil, fmt.Errorf("firewall not supported on %s", runtime.GOOS) -} diff --git a/client/internal/routemanager/iptables_linux.go b/client/internal/routemanager/iptables_linux.go deleted file mode 100644 index e9fbb7d3c..000000000 --- a/client/internal/routemanager/iptables_linux.go +++ /dev/null @@ -1,487 +0,0 @@ -//go:build !android - -package routemanager - -import ( - "context" - "fmt" - "net/netip" - "os/exec" - "strings" - "sync" - - "github.com/coreos/go-iptables/iptables" - log "github.com/sirupsen/logrus" -) - -func isIptablesSupported() bool { - _, err4 := exec.LookPath("iptables") - _, err6 := exec.LookPath("ip6tables") - return err4 == nil && err6 == nil -} - -// constants needed to manage and create iptable rules -const ( - iptablesFilterTable = "filter" - iptablesNatTable = "nat" - iptablesForwardChain = "FORWARD" - iptablesPostRoutingChain = "POSTROUTING" - iptablesRoutingNatChain = "NETBIRD-RT-NAT" - iptablesRoutingForwardingChain = "NETBIRD-RT-FWD" - routingFinalForwardJump = "ACCEPT" - routingFinalNatJump = "MASQUERADE" -) - -// some presets for building nftable rules -var ( - iptablesDefaultForwardingRule = []string{"-j", iptablesRoutingForwardingChain, "-m", "comment", "--comment"} - iptablesDefaultNetbirdForwardingRule = []string{"-j", "RETURN"} - iptablesDefaultNatRule = []string{"-j", iptablesRoutingNatChain, "-m", "comment", "--comment"} - iptablesDefaultNetbirdNatRule = []string{"-j", "RETURN"} -) - -type iptablesManager struct { - ctx context.Context - stop context.CancelFunc - ipv4Client *iptables.IPTables - ipv6Client *iptables.IPTables - rules map[string]map[string][]string - mux sync.Mutex -} - -func newIptablesManager(parentCtx context.Context, ipv6Supported bool) (*iptablesManager, error) { - ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) - if err != nil { - return nil, fmt.Errorf("failed to initialize iptables for ipv4: %s", err) - } - - ctx, cancel := context.WithCancel(parentCtx) - manager := &iptablesManager{ - ctx: ctx, - stop: cancel, - ipv4Client: ipv4Client, - rules: make(map[string]map[string][]string), - } - - if ipv6Supported { - manager.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6) - if err != nil { - log.Warnf("failed to initialize iptables for ipv6: %s. Routes for this protocol won't be applied.", err) - } - } - - return manager, nil -} - -// CleanRoutingRules cleans existing iptables resources that we created by the agent -func (i *iptablesManager) CleanRoutingRules() { - i.mux.Lock() - defer i.mux.Unlock() - - err := i.cleanJumpRules() - if err != nil { - log.Error(err) - } - - log.Debug("flushing tables") - errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v" - if i.ipv4Client != nil { - err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain) - if err != nil { - log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err) - } - - err = i.ipv4Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain) - if err != nil { - log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err) - } - } - - if i.ipv6Client != nil { - err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain) - if err != nil { - log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err) - } - - err = i.ipv6Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain) - if err != nil { - log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err) - } - } - - log.Info("done cleaning up iptables rules") -} - -// RestoreOrCreateContainers restores existing iptables containers (chains and rules) -// if they don't exist, we create them -func (i *iptablesManager) RestoreOrCreateContainers() error { - i.mux.Lock() - defer i.mux.Unlock() - - if i.rules[ipv4][ipv4Forwarding] != nil && i.rules[ipv6][ipv6Forwarding] != nil { - return nil - } - - errMSGFormat := "iptables: failed creating %s chain %s,error: %v" - - if i.ipv4Client != nil { - err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain) - if err != nil { - return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err) - } - - err = createChain(i.ipv4Client, iptablesNatTable, iptablesRoutingNatChain) - if err != nil { - return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err) - } - - err = i.restoreRules(i.ipv4Client) - if err != nil { - return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err) - } - } - - if i.ipv6Client != nil { - err := createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain) - if err != nil { - return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err) - } - - err = createChain(i.ipv6Client, iptablesNatTable, iptablesRoutingNatChain) - if err != nil { - return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err) - } - - err = i.restoreRules(i.ipv6Client) - if err != nil { - return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err) - } - } - - err := i.addJumpRules() - if err != nil { - return fmt.Errorf("iptables: error while creating jump rules: %v", err) - } - - return nil -} - -// addJumpRules create jump rules to send packets to NetBird chains -func (i *iptablesManager) addJumpRules() error { - err := i.cleanJumpRules() - if err != nil { - return err - } - if i.ipv4Client != nil { - rule := append(iptablesDefaultForwardingRule, ipv4Forwarding) //nolint:gocritic - - err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...) - if err != nil { - return err - } - i.rules[ipv4][ipv4Forwarding] = rule - - rule = append(iptablesDefaultNatRule, ipv4Nat) //nolint:gocritic - err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...) - if err != nil { - return err - } - i.rules[ipv4][ipv4Nat] = rule - } - - if i.ipv6Client != nil { - rule := append(iptablesDefaultForwardingRule, ipv6Forwarding) //nolint:gocritic - err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...) - if err != nil { - return err - } - i.rules[ipv6][ipv6Forwarding] = rule - - rule = append(iptablesDefaultNatRule, ipv6Nat) //nolint:gocritic - err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...) - if err != nil { - return err - } - i.rules[ipv6][ipv6Nat] = rule - } - - return nil -} - -// cleanJumpRules cleans jump rules that was sending packets to NetBird chains -func (i *iptablesManager) cleanJumpRules() error { - var err error - errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v" - rule, found := i.rules[ipv4][ipv4Forwarding] - if i.ipv4Client != nil { - if found { - log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding) - err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...) - if err != nil { - return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err) - } - } - rule, found = i.rules[ipv4][ipv4Nat] - if found { - log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Nat) - err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...) - if err != nil { - return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err) - } - } - } - if i.ipv6Client == nil { - rule, found = i.rules[ipv6][ipv6Forwarding] - if found { - log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding) - err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...) - if err != nil { - return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err) - } - } - rule, found = i.rules[ipv6][ipv6Nat] - if found { - log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Nat) - err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...) - if err != nil { - return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err) - } - } - } - return nil -} - -func iptablesProtoToString(proto iptables.Protocol) string { - if proto == iptables.ProtocolIPv6 { - return ipv6 - } - return ipv4 -} - -// restoreRules restores existing NetBird rules -func (i *iptablesManager) restoreRules(iptablesClient *iptables.IPTables) error { - ipVersion := iptablesProtoToString(iptablesClient.Proto()) - - if i.rules[ipVersion] == nil { - i.rules[ipVersion] = make(map[string][]string) - } - table := iptablesFilterTable - for _, chain := range []string{iptablesForwardChain, iptablesRoutingForwardingChain} { - rules, err := iptablesClient.List(table, chain) - if err != nil { - return err - } - for _, ruleString := range rules { - rule := strings.Fields(ruleString) - id := getRuleRouteID(rule) - if id != "" { - i.rules[ipVersion][id] = rule[2:] - } - } - } - - table = iptablesNatTable - for _, chain := range []string{iptablesPostRoutingChain, iptablesRoutingNatChain} { - rules, err := iptablesClient.List(table, chain) - if err != nil { - return err - } - for _, ruleString := range rules { - rule := strings.Fields(ruleString) - id := getRuleRouteID(rule) - if id != "" { - i.rules[ipVersion][id] = rule[2:] - } - } - } - - return nil -} - -// createChain create NetBird chains -func createChain(iptables *iptables.IPTables, table, newChain string) error { - chains, err := iptables.ListChains(table) - if err != nil { - return fmt.Errorf("couldn't get %s %s table chains, error: %v", iptablesProtoToString(iptables.Proto()), table, err) - } - - shouldCreateChain := true - for _, chain := range chains { - if chain == newChain { - shouldCreateChain = false - } - } - - if shouldCreateChain { - err = iptables.NewChain(table, newChain) - if err != nil { - return fmt.Errorf("couldn't create %s chain %s in %s table, error: %v", iptablesProtoToString(iptables.Proto()), newChain, table, err) - } - - if table == iptablesNatTable { - err = iptables.Append(table, newChain, iptablesDefaultNetbirdNatRule...) - } else { - err = iptables.Append(table, newChain, iptablesDefaultNetbirdForwardingRule...) - } - if err != nil { - return fmt.Errorf("couldn't create %s chain %s default rule, error: %v", iptablesProtoToString(iptables.Proto()), newChain, err) - } - - } - return nil -} - -// genRuleSpec generates rule specification with comment identifier -func genRuleSpec(jump, id, source, destination string) []string { - return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id} -} - -// getRuleRouteID returns the rule ID if matches our prefix -func getRuleRouteID(rule []string) string { - for i, flag := range rule { - if flag == "--comment" { - id := rule[i+1] - if strings.HasPrefix(id, "netbird-") { - return id - } - } - } - return "" -} - -// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain -func (i *iptablesManager) InsertRoutingRules(pair routerPair) error { - i.mux.Lock() - defer i.mux.Unlock() - - err := i.insertRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, pair) - if err != nil { - return err - } - - err = i.insertRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, getInPair(pair)) - if err != nil { - return err - } - - if !pair.masquerade { - return nil - } - - err = i.insertRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, pair) - if err != nil { - return err - } - - err = i.insertRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, getInPair(pair)) - if err != nil { - return err - } - - return nil -} - -// insertRoutingRule inserts an iptable rule -func (i *iptablesManager) insertRoutingRule(keyFormat, table, chain, jump string, pair routerPair) error { - var err error - - prefix := netip.MustParsePrefix(pair.source) - ipVersion := ipv4 - iptablesClient := i.ipv4Client - if prefix.Addr().Unmap().Is6() { - iptablesClient = i.ipv6Client - ipVersion = ipv6 - } - - if iptablesClient == nil { - return fmt.Errorf("unable to insert iptables routing rules. Iptables client is not initialized") - } - - ruleKey := genKey(keyFormat, pair.ID) - rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination) - existingRule, found := i.rules[ipVersion][ruleKey] - if found { - err = iptablesClient.DeleteIfExists(table, chain, existingRule...) - if err != nil { - return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err) - } - delete(i.rules[ipVersion], ruleKey) - } - err = iptablesClient.Insert(table, chain, 1, rule...) - if err != nil { - return fmt.Errorf("iptables: error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err) - } - - i.rules[ipVersion][ruleKey] = rule - - return nil -} - -// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains -func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error { - i.mux.Lock() - defer i.mux.Unlock() - - err := i.removeRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, pair) - if err != nil { - return err - } - - err = i.removeRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, getInPair(pair)) - if err != nil { - return err - } - - if !pair.masquerade { - return nil - } - - err = i.removeRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, pair) - if err != nil { - return err - } - - err = i.removeRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, getInPair(pair)) - if err != nil { - return err - } - - return nil -} - -// removeRoutingRule removes an iptables rule -func (i *iptablesManager) removeRoutingRule(keyFormat, table, chain string, pair routerPair) error { - var err error - - prefix := netip.MustParsePrefix(pair.source) - ipVersion := ipv4 - iptablesClient := i.ipv4Client - if prefix.Addr().Unmap().Is6() { - iptablesClient = i.ipv6Client - ipVersion = ipv6 - } - - if iptablesClient == nil { - return fmt.Errorf("unable to remove iptables routing rules. Iptables client is not initialized") - } - - ruleKey := genKey(keyFormat, pair.ID) - existingRule, found := i.rules[ipVersion][ruleKey] - if found { - err = iptablesClient.DeleteIfExists(table, chain, existingRule...) - if err != nil { - return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err) - } - } - delete(i.rules[ipVersion], ruleKey) - - return nil -} - -func getIptablesRuleType(table string) string { - ruleType := "forwarding" - if table == iptablesNatTable { - ruleType = "nat" - } - return ruleType -} diff --git a/client/internal/routemanager/iptables_linux_test.go b/client/internal/routemanager/iptables_linux_test.go deleted file mode 100644 index 4f733de34..000000000 --- a/client/internal/routemanager/iptables_linux_test.go +++ /dev/null @@ -1,294 +0,0 @@ -//go:build !android - -package routemanager - -import ( - "context" - "testing" - - "github.com/coreos/go-iptables/iptables" - "github.com/stretchr/testify/require" -) - -func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { - - if !isIptablesSupported() { - t.SkipNow() - } - - manager, err := newIptablesManager(context.TODO(), true) - require.NoError(t, err, "should return a valid iptables manager") - - defer manager.CleanRoutingRules() - - err = manager.RestoreOrCreateContainers() - require.NoError(t, err, "shouldn't return error") - - require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6") - - require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4") - - exists, err := manager.ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...) - require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain) - require.True(t, exists, "forwarding rule should exist") - - exists, err = manager.ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...) - require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain) - require.True(t, exists, "postrouting rule should exist") - - require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6") - - exists, err = manager.ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...) - require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain) - require.True(t, exists, "forwarding rule should exist") - - exists, err = manager.ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...) - require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain) - require.True(t, exists, "postrouting rule should exist") - - pair := routerPair{ - ID: "abc", - source: "100.100.100.1/32", - destination: "100.100.100.0/24", - masquerade: true, - } - forward4RuleKey := genKey(forwardingFormat, pair.ID) - forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination) - - err = manager.ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...) - require.NoError(t, err, "inserting rule should not return error") - - nat4RuleKey := genKey(natFormat, pair.ID) - nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination) - - err = manager.ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...) - require.NoError(t, err, "inserting rule should not return error") - - pair = routerPair{ - ID: "abc", - source: "fc00::1/128", - destination: "fc11::/64", - masquerade: true, - } - - forward6RuleKey := genKey(forwardingFormat, pair.ID) - forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination) - - err = manager.ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...) - require.NoError(t, err, "inserting rule should not return error") - - nat6RuleKey := genKey(natFormat, pair.ID) - nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination) - - err = manager.ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...) - require.NoError(t, err, "inserting rule should not return error") - - delete(manager.rules, ipv4) - delete(manager.rules, ipv6) - - err = manager.RestoreOrCreateContainers() - require.NoError(t, err, "shouldn't return error") - - require.Len(t, manager.rules[ipv4], 4, "should have restored all rules for ipv4") - - foundRule, found := manager.rules[ipv4][forward4RuleKey] - require.True(t, found, "forwarding rule should exist in the map") - require.Equal(t, forward4Rule[:4], foundRule[:4], "stored forwarding rule should match") - - foundRule, found = manager.rules[ipv4][nat4RuleKey] - require.True(t, found, "nat rule should exist in the map") - require.Equal(t, nat4Rule[:4], foundRule[:4], "stored nat rule should match") - - require.Len(t, manager.rules[ipv6], 4, "should have restored all rules for ipv6") - - foundRule, found = manager.rules[ipv6][forward6RuleKey] - require.True(t, found, "forwarding rule should exist in the map") - require.Equal(t, forward6Rule[:4], foundRule[:4], "stored forward rule should match") - - foundRule, found = manager.rules[ipv6][nat6RuleKey] - require.True(t, found, "nat rule should exist in the map") - require.Equal(t, nat6Rule[:4], foundRule[:4], "stored nat rule should match") -} - -func TestIptablesManager_InsertRoutingRules(t *testing.T) { - - if !isIptablesSupported() { - t.SkipNow() - } - - for _, testCase := range insertRuleTestCases { - t.Run(testCase.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6) - iptablesClient := ipv4Client - if testCase.ipVersion == ipv6 { - iptablesClient = ipv6Client - } - - manager := &iptablesManager{ - ctx: ctx, - stop: cancel, - ipv4Client: ipv4Client, - ipv6Client: ipv6Client, - rules: make(map[string]map[string][]string), - } - - defer manager.CleanRoutingRules() - - err := manager.RestoreOrCreateContainers() - require.NoError(t, err, "shouldn't return error") - - err = manager.InsertRoutingRules(testCase.inputPair) - require.NoError(t, err, "forwarding pair should be inserted") - - forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID) - forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.inputPair.source, testCase.inputPair.destination) - - exists, err := iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, forwardRule...) - require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain) - require.True(t, exists, "forwarding rule should exist") - - foundRule, found := manager.rules[testCase.ipVersion][forwardRuleKey] - require.True(t, found, "forwarding rule should exist in the manager map") - require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match") - - inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID) - inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination) - - exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...) - require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain) - require.True(t, exists, "income forwarding rule should exist") - - foundRule, found = manager.rules[testCase.ipVersion][inForwardRuleKey] - require.True(t, found, "income forwarding rule should exist in the manager map") - require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match") - - natRuleKey := genKey(natFormat, testCase.inputPair.ID) - natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination) - - exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...) - require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain) - if testCase.inputPair.masquerade { - require.True(t, exists, "nat rule should be created") - foundNatRule, foundNat := manager.rules[testCase.ipVersion][natRuleKey] - require.True(t, foundNat, "nat rule should exist in the map") - require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match") - } else { - require.False(t, exists, "nat rule should not be created") - _, foundNat := manager.rules[testCase.ipVersion][natRuleKey] - require.False(t, foundNat, "nat rule should not exist in the map") - } - - inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID) - inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination) - - exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...) - require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain) - if testCase.inputPair.masquerade { - require.True(t, exists, "income nat rule should be created") - foundNatRule, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey] - require.True(t, foundNat, "income nat rule should exist in the map") - require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match") - } else { - require.False(t, exists, "nat rule should not be created") - _, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey] - require.False(t, foundNat, "income nat rule should not exist in the map") - } - }) - } -} - -func TestIptablesManager_RemoveRoutingRules(t *testing.T) { - - if !isIptablesSupported() { - t.SkipNow() - } - - for _, testCase := range removeRuleTestCases { - t.Run(testCase.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6) - iptablesClient := ipv4Client - if testCase.ipVersion == ipv6 { - iptablesClient = ipv6Client - } - - manager := &iptablesManager{ - ctx: ctx, - stop: cancel, - ipv4Client: ipv4Client, - ipv6Client: ipv6Client, - rules: make(map[string]map[string][]string), - } - - defer manager.CleanRoutingRules() - - err := manager.RestoreOrCreateContainers() - require.NoError(t, err, "shouldn't return error") - - forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID) - forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.inputPair.source, testCase.inputPair.destination) - - err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...) - require.NoError(t, err, "inserting rule should not return error") - - inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID) - inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination) - - err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, inForwardRule...) - require.NoError(t, err, "inserting rule should not return error") - - natRuleKey := genKey(natFormat, testCase.inputPair.ID) - natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination) - - err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...) - require.NoError(t, err, "inserting rule should not return error") - - inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID) - inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination) - - err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, inNatRule...) - require.NoError(t, err, "inserting rule should not return error") - - delete(manager.rules, ipv4) - delete(manager.rules, ipv6) - - err = manager.RestoreOrCreateContainers() - require.NoError(t, err, "shouldn't return error") - - err = manager.RemoveRoutingRules(testCase.inputPair) - require.NoError(t, err, "shouldn't return error") - - exists, err := iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, forwardRule...) - require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain) - require.False(t, exists, "forwarding rule should not exist") - - _, found := manager.rules[testCase.ipVersion][forwardRuleKey] - require.False(t, found, "forwarding rule should exist in the manager map") - - exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...) - require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain) - require.False(t, exists, "income forwarding rule should not exist") - - _, found = manager.rules[testCase.ipVersion][inForwardRuleKey] - require.False(t, found, "income forwarding rule should exist in the manager map") - - exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...) - require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain) - require.False(t, exists, "nat rule should not exist") - - _, found = manager.rules[testCase.ipVersion][natRuleKey] - require.False(t, found, "nat rule should exist in the manager map") - - exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...) - require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain) - require.False(t, exists, "income nat rule should not exist") - - _, found = manager.rules[testCase.ipVersion][inNatRuleKey] - require.False(t, found, "income nat rule should exist in the manager map") - - }) - } -} diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 479ac873f..e8a4bd134 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" @@ -19,6 +20,7 @@ type Manager interface { UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string + EnableServerRouter(firewall firewall.Manager) error Stop() } @@ -35,19 +37,12 @@ type DefaultManager struct { notifier *notifier } -// NewManager returns a new route manager func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager { - srvRouter, err := newServerRouter(ctx, wgInterface) - if err != nil { - log.Errorf("server router is not supported: %s", err) - } - mCTX, cancel := context.WithCancel(ctx) dm := &DefaultManager{ ctx: mCTX, stop: cancel, clientNetworks: make(map[string]*clientNetwork), - serverRouter: srvRouter, statusRecorder: statusRecorder, wgInterface: wgInterface, pubKey: pubKey, @@ -61,6 +56,15 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, return dm } +func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { + var err error + m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall) + if err != nil { + return err + } + return nil +} + // Stop stops the manager watchers and clean firewall rules func (m *DefaultManager) Stop() { m.stop() diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 8970841a2..a1214cbb9 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" @@ -37,6 +38,10 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList } +func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error { + panic("implement me") +} + // Stop mock implementation of Stop from Manager interface func (m *MockManager) Stop() { if m.StopFunc != nil { diff --git a/client/internal/routemanager/nftables_linux.go b/client/internal/routemanager/nftables_linux.go deleted file mode 100644 index 3ecfa9630..000000000 --- a/client/internal/routemanager/nftables_linux.go +++ /dev/null @@ -1,571 +0,0 @@ -//go:build !android - -package routemanager - -import ( - "context" - "fmt" - "net" - "net/netip" - "sync" - - "github.com/google/nftables" - "github.com/google/nftables/binaryutil" - "github.com/google/nftables/expr" - log "github.com/sirupsen/logrus" -) - -const ( - nftablesTable = "netbird-rt" - nftablesRoutingForwardingChain = "netbird-rt-fwd" - nftablesRoutingNatChain = "netbird-rt-nat" - - userDataAcceptForwardRuleSrc = "frwacceptsrc" - userDataAcceptForwardRuleDst = "frwacceptdst" -) - -// constants needed to create nftable rules -const ( - ipv4Len = 4 - ipv4SrcOffset = 12 - ipv4DestOffset = 16 - ipv6Len = 16 - ipv6SrcOffset = 8 - ipv6DestOffset = 24 - exprDirectionSource = "source" - exprDirectionDestination = "destination" -) - -// some presets for building nftable rules -var ( - zeroXor = binaryutil.NativeEndian.PutUint32(0) - - zeroXor6 = append(binaryutil.NativeEndian.PutUint64(0), binaryutil.NativeEndian.PutUint64(0)...) - - exprAllowRelatedEstablished = []expr.Any{ - &expr.Ct{ - Register: 1, - SourceRegister: false, - Key: 0, - }, - &expr.Bitwise{ - DestRegister: 1, - SourceRegister: 1, - Len: 4, - Mask: []uint8{0x6, 0x0, 0x0, 0x0}, - Xor: zeroXor, - }, - &expr.Cmp{ - Register: 1, - Data: binaryutil.NativeEndian.PutUint32(0), - }, - &expr.Counter{}, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - - exprCounterAccept = []expr.Any{ - &expr.Counter{}, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } -) - -type nftablesManager struct { - ctx context.Context - stop context.CancelFunc - conn *nftables.Conn - tableIPv4 *nftables.Table - tableIPv6 *nftables.Table - chains map[string]map[string]*nftables.Chain - rules map[string]*nftables.Rule - filterTable *nftables.Table - defaultForwardRules []*nftables.Rule - mux sync.Mutex -} - -func newNFTablesManager(parentCtx context.Context) *nftablesManager { - ctx, cancel := context.WithCancel(parentCtx) - - return &nftablesManager{ - ctx: ctx, - stop: cancel, - conn: &nftables.Conn{}, - chains: make(map[string]map[string]*nftables.Chain), - rules: make(map[string]*nftables.Rule), - defaultForwardRules: make([]*nftables.Rule, 2), - } -} - -// CleanRoutingRules cleans existing nftables rules from the system -func (n *nftablesManager) CleanRoutingRules() { - n.mux.Lock() - defer n.mux.Unlock() - log.Debug("flushing tables") - if n.tableIPv4 != nil && n.tableIPv6 != nil { - n.conn.FlushTable(n.tableIPv6) - n.conn.FlushTable(n.tableIPv4) - } - - if n.defaultForwardRules[0] != nil { - err := n.eraseDefaultForwardRule() - if err != nil { - log.Errorf("failed to delete forward rule: %s", err) - } - } - log.Debugf("flushing tables result in: %v error", n.conn.Flush()) -} - -// RestoreOrCreateContainers restores existing nftables containers (tables and chains) -// if they don't exist, we create them -func (n *nftablesManager) RestoreOrCreateContainers() error { - n.mux.Lock() - defer n.mux.Unlock() - - if n.tableIPv6 != nil && n.tableIPv4 != nil { - log.Debugf("nftables: containers already restored, skipping") - return nil - } - - tables, err := n.conn.ListTables() - if err != nil { - return fmt.Errorf("nftables: unable to list tables: %v", err) - } - - for _, table := range tables { - if table.Name == "filter" && table.Family == nftables.TableFamilyIPv4 { - log.Debugf("nftables: found filter table for ipv4") - n.filterTable = table - continue - } - if table.Name == nftablesTable { - if table.Family == nftables.TableFamilyIPv4 { - n.tableIPv4 = table - continue - } - n.tableIPv6 = table - } - } - - if n.tableIPv4 == nil { - n.tableIPv4 = n.conn.AddTable(&nftables.Table{ - Name: nftablesTable, - Family: nftables.TableFamilyIPv4, - }) - } - - if n.tableIPv6 == nil { - n.tableIPv6 = n.conn.AddTable(&nftables.Table{ - Name: nftablesTable, - Family: nftables.TableFamilyIPv6, - }) - } - - chains, err := n.conn.ListChains() - if err != nil { - return fmt.Errorf("nftables: unable to list chains: %v", err) - } - - n.chains[ipv4] = make(map[string]*nftables.Chain) - n.chains[ipv6] = make(map[string]*nftables.Chain) - - for _, chain := range chains { - switch { - case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv4: - n.chains[ipv4][chain.Name] = chain - case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv6: - n.chains[ipv6][chain.Name] = chain - } - } - - if _, found := n.chains[ipv4][nftablesRoutingForwardingChain]; !found { - n.chains[ipv4][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ - Name: nftablesRoutingForwardingChain, - Table: n.tableIPv4, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityNATDest + 1, - Type: nftables.ChainTypeFilter, - }) - } - - if _, found := n.chains[ipv4][nftablesRoutingNatChain]; !found { - n.chains[ipv4][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ - Name: nftablesRoutingNatChain, - Table: n.tableIPv4, - Hooknum: nftables.ChainHookPostrouting, - Priority: nftables.ChainPriorityNATSource - 1, - Type: nftables.ChainTypeNAT, - }) - } - - if _, found := n.chains[ipv6][nftablesRoutingForwardingChain]; !found { - n.chains[ipv6][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{ - Name: nftablesRoutingForwardingChain, - Table: n.tableIPv6, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityNATDest + 1, - Type: nftables.ChainTypeFilter, - }) - } - - if _, found := n.chains[ipv6][nftablesRoutingNatChain]; !found { - n.chains[ipv6][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{ - Name: nftablesRoutingNatChain, - Table: n.tableIPv6, - Hooknum: nftables.ChainHookPostrouting, - Priority: nftables.ChainPriorityNATSource - 1, - Type: nftables.ChainTypeNAT, - }) - } - - err = n.refreshRulesMap() - if err != nil { - return err - } - - n.checkOrCreateDefaultForwardingRules() - err = n.conn.Flush() - if err != nil { - return fmt.Errorf("nftables: unable to initialize table: %v", err) - } - return nil -} - -// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid -// duplicates and to get missing attributes that we don't have when adding new rules -func (n *nftablesManager) refreshRulesMap() error { - for _, registeredChains := range n.chains { - for _, chain := range registeredChains { - rules, err := n.conn.GetRules(chain.Table, chain) - if err != nil { - return fmt.Errorf("nftables: unable to list rules: %v", err) - } - for _, rule := range rules { - if len(rule.UserData) > 0 { - n.rules[string(rule.UserData)] = rule - } - } - } - } - return nil -} - -func (n *nftablesManager) eraseDefaultForwardRule() error { - if n.defaultForwardRules[0] == nil { - return nil - } - - err := n.refreshDefaultForwardRule() - if err != nil { - return err - } - - for i, r := range n.defaultForwardRules { - err = n.conn.DelRule(r) - if err != nil { - log.Errorf("failed to delete forward rule (%d): %s", i, err) - } - n.defaultForwardRules[i] = nil - } - return nil -} - -func (n *nftablesManager) refreshDefaultForwardRule() error { - rules, err := n.conn.GetRules(n.defaultForwardRules[0].Table, n.defaultForwardRules[0].Chain) - if err != nil { - return fmt.Errorf("unable to list rules in forward chain: %s", err) - } - - found := false - for i, r := range n.defaultForwardRules { - for _, rule := range rules { - if string(rule.UserData) == string(r.UserData) { - n.defaultForwardRules[i] = rule - found = true - break - } - } - } - if !found { - return fmt.Errorf("unable to find forward accept rule") - } - - return nil -} - -func (n *nftablesManager) acceptForwardRule(sourceNetwork string) error { - src := generateCIDRMatcherExpressions("source", sourceNetwork) - dst := generateCIDRMatcherExpressions("destination", "0.0.0.0/0") - - var exprs []expr.Any - exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic - Kind: expr.VerdictAccept, - })...) - - r := &nftables.Rule{ - Table: n.filterTable, - Chain: &nftables.Chain{ - Name: "FORWARD", - Table: n.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, - Exprs: exprs, - UserData: []byte(userDataAcceptForwardRuleSrc), - } - - n.defaultForwardRules[0] = n.conn.AddRule(r) - - src = generateCIDRMatcherExpressions("source", "0.0.0.0/0") - dst = generateCIDRMatcherExpressions("destination", sourceNetwork) - - exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic - Kind: expr.VerdictAccept, - })...) - - r = &nftables.Rule{ - Table: n.filterTable, - Chain: &nftables.Chain{ - Name: "FORWARD", - Table: n.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, - Exprs: exprs, - UserData: []byte(userDataAcceptForwardRuleDst), - } - - n.defaultForwardRules[1] = n.conn.AddRule(r) - return nil -} - -// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled -func (n *nftablesManager) checkOrCreateDefaultForwardingRules() { - _, foundIPv4 := n.rules[ipv4Forwarding] - if !foundIPv4 { - n.rules[ipv4Forwarding] = n.conn.AddRule(&nftables.Rule{ - Table: n.tableIPv4, - Chain: n.chains[ipv4][nftablesRoutingForwardingChain], - Exprs: exprAllowRelatedEstablished, - UserData: []byte(ipv4Forwarding), - }) - } - - _, foundIPv6 := n.rules[ipv6Forwarding] - if !foundIPv6 { - n.rules[ipv6Forwarding] = n.conn.AddRule(&nftables.Rule{ - Table: n.tableIPv6, - Chain: n.chains[ipv6][nftablesRoutingForwardingChain], - Exprs: exprAllowRelatedEstablished, - UserData: []byte(ipv6Forwarding), - }) - } -} - -// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain -func (n *nftablesManager) InsertRoutingRules(pair routerPair) error { - n.mux.Lock() - defer n.mux.Unlock() - - err := n.refreshRulesMap() - if err != nil { - return err - } - - err = n.insertRoutingRule(forwardingFormat, nftablesRoutingForwardingChain, pair, false) - if err != nil { - return err - } - err = n.insertRoutingRule(inForwardingFormat, nftablesRoutingForwardingChain, getInPair(pair), false) - if err != nil { - return err - } - - if pair.masquerade { - err = n.insertRoutingRule(natFormat, nftablesRoutingNatChain, pair, true) - if err != nil { - return err - } - err = n.insertRoutingRule(inNatFormat, nftablesRoutingNatChain, getInPair(pair), true) - if err != nil { - return err - } - } - - if n.defaultForwardRules[0] == nil && n.filterTable != nil { - err = n.acceptForwardRule(pair.source) - if err != nil { - log.Errorf("unable to create default forward rule: %s", err) - } - log.Debugf("default accept forward rule added") - } - - err = n.conn.Flush() - if err != nil { - return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err) - } - return nil -} - -// insertRoutingRule inserts a nftable rule to the conn client flush queue -func (n *nftablesManager) insertRoutingRule(format, chain string, pair routerPair, isNat bool) error { - - prefix := netip.MustParsePrefix(pair.source) - - sourceExp := generateCIDRMatcherExpressions("source", pair.source) - destExp := generateCIDRMatcherExpressions("destination", pair.destination) - - var expression []expr.Any - if isNat { - expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic - } else { - expression = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic - } - - ruleKey := genKey(format, pair.ID) - - _, exists := n.rules[ruleKey] - if exists { - err := n.removeRoutingRule(format, pair) - if err != nil { - return err - } - } - - if prefix.Addr().Unmap().Is4() { - n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{ - Table: n.tableIPv4, - Chain: n.chains[ipv4][chain], - Exprs: expression, - UserData: []byte(ruleKey), - }) - } else { - n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{ - Table: n.tableIPv6, - Chain: n.chains[ipv6][chain], - Exprs: expression, - UserData: []byte(ruleKey), - }) - } - return nil -} - -// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains -func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error { - n.mux.Lock() - defer n.mux.Unlock() - - err := n.refreshRulesMap() - if err != nil { - return err - } - - err = n.removeRoutingRule(forwardingFormat, pair) - if err != nil { - return err - } - - err = n.removeRoutingRule(inForwardingFormat, getInPair(pair)) - if err != nil { - return err - } - - err = n.removeRoutingRule(natFormat, pair) - if err != nil { - return err - } - - err = n.removeRoutingRule(inNatFormat, getInPair(pair)) - if err != nil { - return err - } - - if len(n.rules) == 2 && n.defaultForwardRules[0] != nil { - err := n.eraseDefaultForwardRule() - if err != nil { - log.Errorf("failed to delete default fwd rule: %s", err) - } - } - - err = n.conn.Flush() - if err != nil { - return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err) - } - log.Debugf("nftables: removed rules for %s", pair.destination) - return nil -} - -// removeRoutingRule add a nftable rule to the removal queue and delete from rules map -func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) error { - ruleKey := genKey(format, pair.ID) - - rule, found := n.rules[ruleKey] - if found { - ruleType := "forwarding" - if rule.Chain.Type == nftables.ChainTypeNAT { - ruleType = "nat" - } - - err := n.conn.DelRule(rule) - if err != nil { - return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.destination, err) - } - - log.Debugf("nftables: removing %s rule for %s", ruleType, pair.destination) - - delete(n.rules, ruleKey) - } - return nil -} - -// getPayloadDirectives get expression directives based on ip version and direction -func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) { - switch { - case direction == exprDirectionSource && isIPv4: - return ipv4SrcOffset, ipv4Len, zeroXor - case direction == exprDirectionDestination && isIPv4: - return ipv4DestOffset, ipv4Len, zeroXor - case direction == exprDirectionSource && isIPv6: - return ipv6SrcOffset, ipv6Len, zeroXor6 - case direction == exprDirectionDestination && isIPv6: - return ipv6DestOffset, ipv6Len, zeroXor6 - default: - panic("no matched payload directive") - } -} - -// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR -func generateCIDRMatcherExpressions(direction string, cidr string) []expr.Any { - ip, network, _ := net.ParseCIDR(cidr) - ipToAdd, _ := netip.AddrFromSlice(ip) - add := ipToAdd.Unmap() - - offSet, packetLen, zeroXor := getPayloadDirectives(direction, add.Is4(), add.Is6()) - - return []expr.Any{ - // fetch src add - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: offSet, - Len: packetLen, - }, - // net mask - &expr.Bitwise{ - DestRegister: 1, - SourceRegister: 1, - Len: packetLen, - Mask: network.Mask, - Xor: zeroXor, - }, - // net address - &expr.Cmp{ - Register: 1, - Data: add.AsSlice(), - }, - } -} diff --git a/client/internal/routemanager/nftables_linux_test.go b/client/internal/routemanager/nftables_linux_test.go deleted file mode 100644 index d60d53e50..000000000 --- a/client/internal/routemanager/nftables_linux_test.go +++ /dev/null @@ -1,324 +0,0 @@ -//go:build !android - -package routemanager - -import ( - "context" - "testing" - - "github.com/google/nftables" - "github.com/google/nftables/expr" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/client/internal/checkfw" -) - -func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) { - - if checkfw.Check() != checkfw.NFTABLES { - t.Skip("nftables not supported on this OS") - } - - manager := newNFTablesManager(context.TODO()) - - nftablesTestingClient := &nftables.Conn{} - - defer manager.CleanRoutingRules() - - err := manager.RestoreOrCreateContainers() - require.NoError(t, err, "shouldn't return error") - - require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6") - require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4") - require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6") - require.Len(t, manager.rules, 2, "should have created rules for ipv4 and ipv6") - - pair := routerPair{ - ID: "abc", - source: "100.100.100.1/32", - destination: "100.100.100.0/24", - masquerade: true, - } - - sourceExp := generateCIDRMatcherExpressions("source", pair.source) - destExp := generateCIDRMatcherExpressions("destination", pair.destination) - - forward4Exp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic - forward4RuleKey := genKey(forwardingFormat, pair.ID) - inserted4Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: manager.tableIPv4, - Chain: manager.chains[ipv4][nftablesRoutingForwardingChain], - Exprs: forward4Exp, - UserData: []byte(forward4RuleKey), - }) - - nat4Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic - nat4RuleKey := genKey(natFormat, pair.ID) - - inserted4Nat := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: manager.tableIPv4, - Chain: manager.chains[ipv4][nftablesRoutingNatChain], - Exprs: nat4Exp, - UserData: []byte(nat4RuleKey), - }) - - err = nftablesTestingClient.Flush() - require.NoError(t, err, "shouldn't return error") - - pair = routerPair{ - ID: "xyz", - source: "fc00::1/128", - destination: "fc11::/64", - masquerade: true, - } - - sourceExp = generateCIDRMatcherExpressions("source", pair.source) - destExp = generateCIDRMatcherExpressions("destination", pair.destination) - - forward6Exp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic - forward6RuleKey := genKey(forwardingFormat, pair.ID) - inserted6Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: manager.tableIPv6, - Chain: manager.chains[ipv6][nftablesRoutingForwardingChain], - Exprs: forward6Exp, - UserData: []byte(forward6RuleKey), - }) - - nat6Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic - nat6RuleKey := genKey(natFormat, pair.ID) - - inserted6Nat := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: manager.tableIPv6, - Chain: manager.chains[ipv6][nftablesRoutingNatChain], - Exprs: nat6Exp, - UserData: []byte(nat6RuleKey), - }) - - err = nftablesTestingClient.Flush() - require.NoError(t, err, "shouldn't return error") - - manager.tableIPv4 = nil - manager.tableIPv6 = nil - - err = manager.RestoreOrCreateContainers() - require.NoError(t, err, "shouldn't return error") - - require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6") - require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4") - require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6") - require.Len(t, manager.rules, 6, "should have restored all rules for ipv4 and ipv6") - - foundRule, found := manager.rules[forward4RuleKey] - require.True(t, found, "forwarding rule should exist in the map") - assert.Equal(t, inserted4Forwarding.Exprs, foundRule.Exprs, "stored forwarding rule expressions should match") - - foundRule, found = manager.rules[nat4RuleKey] - require.True(t, found, "nat rule should exist in the map") - // match len of output as nftables client doesn't return expressions with masquerade expression - assert.ElementsMatch(t, inserted4Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule expressions should match") - - foundRule, found = manager.rules[forward6RuleKey] - require.True(t, found, "forwarding rule should exist in the map") - assert.Equal(t, inserted6Forwarding.Exprs, foundRule.Exprs, "stored forward rule should match") - - foundRule, found = manager.rules[nat6RuleKey] - require.True(t, found, "nat rule should exist in the map") - // match len of output as nftables client doesn't return expressions with masquerade expression - assert.ElementsMatch(t, inserted6Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule should match") -} - -func TestNftablesManager_InsertRoutingRules(t *testing.T) { - if checkfw.Check() != checkfw.NFTABLES { - t.Skip("nftables not supported on this OS") - } - - for _, testCase := range insertRuleTestCases { - t.Run(testCase.name, func(t *testing.T) { - manager := newNFTablesManager(context.TODO()) - - nftablesTestingClient := &nftables.Conn{} - - defer manager.CleanRoutingRules() - - err := manager.RestoreOrCreateContainers() - require.NoError(t, err, "shouldn't return error") - - err = manager.InsertRoutingRules(testCase.inputPair) - require.NoError(t, err, "forwarding pair should be inserted") - - sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source) - destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination) - testingExpression := append(sourceExp, destExp...) //nolint:gocritic - fwdRuleKey := genKey(forwardingFormat, testCase.inputPair.ID) - - found := 0 - for _, registeredChains := range manager.chains { - for _, chain := range registeredChains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey { - require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match") - found = 1 - } - } - } - } - - require.Equal(t, 1, found, "should find at least 1 rule to test") - - if testCase.inputPair.masquerade { - natRuleKey := genKey(natFormat, testCase.inputPair.ID) - found := 0 - for _, registeredChains := range manager.chains { - for _, chain := range registeredChains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { - require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match") - found = 1 - } - } - } - } - require.Equal(t, 1, found, "should find at least 1 rule to test") - } - - sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source) - destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination) - testingExpression = append(sourceExp, destExp...) //nolint:gocritic - inFwdRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID) - - found = 0 - for _, registeredChains := range manager.chains { - for _, chain := range registeredChains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey { - require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match") - found = 1 - } - } - } - } - - require.Equal(t, 1, found, "should find at least 1 rule to test") - - if testCase.inputPair.masquerade { - inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID) - found := 0 - for _, registeredChains := range manager.chains { - for _, chain := range registeredChains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey { - require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match") - found = 1 - } - } - } - } - require.Equal(t, 1, found, "should find at least 1 rule to test") - } - }) - } -} - -func TestNftablesManager_RemoveRoutingRules(t *testing.T) { - if checkfw.Check() != checkfw.NFTABLES { - t.Skip("nftables not supported on this OS") - } - - for _, testCase := range removeRuleTestCases { - t.Run(testCase.name, func(t *testing.T) { - manager := newNFTablesManager(context.TODO()) - - nftablesTestingClient := &nftables.Conn{} - - defer manager.CleanRoutingRules() - - err := manager.RestoreOrCreateContainers() - require.NoError(t, err, "shouldn't return error") - - table := manager.tableIPv4 - if testCase.ipVersion == ipv6 { - table = manager.tableIPv6 - } - - sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source) - destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination) - - forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic - forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID) - insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: table, - Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain], - Exprs: forwardExp, - UserData: []byte(forwardRuleKey), - }) - - natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic - natRuleKey := genKey(natFormat, testCase.inputPair.ID) - - insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: table, - Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain], - Exprs: natExp, - UserData: []byte(natRuleKey), - }) - - sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source) - destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination) - - forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic - inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID) - insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: table, - Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain], - Exprs: forwardExp, - UserData: []byte(inForwardRuleKey), - }) - - natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic - inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID) - - insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: table, - Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain], - Exprs: natExp, - UserData: []byte(inNatRuleKey), - }) - - err = nftablesTestingClient.Flush() - require.NoError(t, err, "shouldn't return error") - - manager.tableIPv4 = nil - manager.tableIPv6 = nil - - err = manager.RestoreOrCreateContainers() - require.NoError(t, err, "shouldn't return error") - - err = manager.RemoveRoutingRules(testCase.inputPair) - require.NoError(t, err, "shouldn't return error") - - for _, registeredChains := range manager.chains { - for _, chain := range registeredChains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 { - require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist") - require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist") - require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist") - require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist") - } - } - } - } - }) - } -} diff --git a/client/internal/routemanager/router_pair.go b/client/internal/routemanager/router_pair.go deleted file mode 100644 index 6836720f1..000000000 --- a/client/internal/routemanager/router_pair.go +++ /dev/null @@ -1,24 +0,0 @@ -package routemanager - -import ( - "net/netip" - - "github.com/netbirdio/netbird/route" -) - -type routerPair struct { - ID string - source string - destination string - masquerade bool -} - -func routeToRouterPair(source string, route *route.Route) routerPair { - parsed := netip.MustParsePrefix(source).Masked() - return routerPair{ - ID: route.ID, - source: parsed.String(), - destination: route.Network.Masked().String(), - masquerade: route.Masquerade, - } -} diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go index d130acc00..7eafabd77 100644 --- a/client/internal/routemanager/server_android.go +++ b/client/internal/routemanager/server_android.go @@ -4,9 +4,10 @@ import ( "context" "fmt" + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/iface" ) -func newServerRouter(context.Context, *iface.WGIface) (serverRouter, error) { +func newServerRouter(context.Context, *iface.WGIface, firewall.Manager) (serverRouter, error) { return nil, fmt.Errorf("server route not supported on this os") } diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 6df632329..20e500e79 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -4,11 +4,12 @@ package routemanager import ( "context" - "fmt" + "net/netip" "sync" log "github.com/sirupsen/logrus" + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) @@ -17,16 +18,11 @@ type defaultServerRouter struct { mux sync.Mutex ctx context.Context routes map[string]*route.Route - firewall firewallManager + firewall firewall.Manager wgInterface *iface.WGIface } -func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRouter, error) { - firewall, err := newFirewall(ctx) - if err != nil { - return nil, err - } - +func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager) (serverRouter, error) { return &defaultServerRouter{ ctx: ctx, routes: make(map[string]*route.Route), @@ -38,13 +34,6 @@ func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRou func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error { serverRoutesToRemove := make([]string, 0) - if len(routesMap) > 0 { - err := m.firewall.RestoreOrCreateContainers() - if err != nil { - return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err) - } - } - for routeID := range m.routes { update, found := routesMap[routeID] if !found || !update.IsEqual(m.routes[routeID]) { @@ -121,5 +110,22 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { } func (m *defaultServerRouter) cleanUp() { - m.firewall.CleanRoutingRules() + m.mux.Lock() + defer m.mux.Unlock() + for _, r := range m.routes { + err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r)) + if err != nil { + log.Warnf("failed to remove clean up route: %s", r.ID) + } + } +} + +func routeToRouterPair(source string, route *route.Route) firewall.RouterPair { + parsed := netip.MustParsePrefix(source).Masked() + return firewall.RouterPair{ + ID: route.ID, + Source: parsed.String(), + Destination: route.Network.Masked().String(), + Masquerade: route.Masquerade, + } } diff --git a/client/internal/routemanager/systemops_nonandroid_test.go b/client/internal/routemanager/systemops_nonandroid_test.go index 3646dc3da..2ae1e0ec4 100644 --- a/client/internal/routemanager/systemops_nonandroid_test.go +++ b/client/internal/routemanager/systemops_nonandroid_test.go @@ -1,3 +1,5 @@ +//go:build !android + package routemanager import (