diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index e4debc179..7f95992da 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -27,7 +27,11 @@ import ( ) const ( - tableNat = "nat" + tableNat = "nat" + tableMangle = "mangle" + tableRaw = "raw" + tableSecurity = "security" + chainNameNatPrerouting = "PREROUTING" chainNameRoutingFw = "netbird-rt-fwd" chainNameRoutingNat = "netbird-rt-postrouting" @@ -91,7 +95,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou var err error r.filterTable, err = r.loadFilterTable() if err != nil { - log.Warnf("failed to load filter table, skipping accept rules: %v", err) + log.Debugf("ip filter table not found: %v", err) } return r, nil @@ -183,6 +187,33 @@ func (r *router) loadFilterTable() (*nftables.Table, error) { return nil, errFilterTableNotFound } +func hookName(hook *nftables.ChainHook) string { + if hook == nil { + return "unknown" + } + switch *hook { + case *nftables.ChainHookForward: + return chainNameForward + case *nftables.ChainHookInput: + return chainNameInput + default: + return fmt.Sprintf("hook(%d)", *hook) + } +} + +func familyName(family nftables.TableFamily) string { + switch family { + case nftables.TableFamilyIPv4: + return "ip" + case nftables.TableFamilyIPv6: + return "ip6" + case nftables.TableFamilyINet: + return "inet" + default: + return fmt.Sprintf("family(%d)", family) + } +} + func (r *router) createContainers() error { r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingFw, @@ -930,8 +961,21 @@ func (r *router) RemoveAllLegacyRouteRules() error { // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. // This method also adds INPUT chain rules to allow traffic to the local interface. func (r *router) acceptForwardRules() error { + var merr *multierror.Error + + if err := r.acceptFilterTableRules(); err != nil { + merr = multierror.Append(merr, err) + } + + if err := r.acceptExternalChainsRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) acceptFilterTableRules() error { if r.filterTable == nil { - log.Debugf("table 'filter' not found for forward rules, skipping accept rules") return nil } @@ -944,11 +988,11 @@ func (r *router) acceptForwardRules() error { // Try iptables first and fallback to nftables if iptables is not available ipt, err := iptables.New() if err != nil { - // filter table exists but iptables is not + // iptables is not available but the filter table exists log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) fw = "nftables" - return r.acceptFilterRulesNftables() + return r.acceptFilterRulesNftables(r.filterTable) } return r.acceptFilterRulesIptables(ipt) @@ -959,7 +1003,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { for _, rule := range r.getAcceptForwardRules() { if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err)) + merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err)) } else { log.Debugf("added iptables forward rule: %v", rule) } @@ -967,7 +1011,7 @@ func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { inputRule := r.getAcceptInputRule() if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err)) + merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err)) } else { log.Debugf("added iptables input rule: %v", inputRule) } @@ -987,18 +1031,70 @@ func (r *router) getAcceptInputRule() []string { return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"} } -func (r *router) acceptFilterRulesNftables() error { +// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables. +// This is used when iptables is not available. +func (r *router) acceptFilterRulesNftables(table *nftables.Table) error { intf := ifname(r.wgIface.Name()) + forwardChain := &nftables.Chain{ + Name: chainNameForward, + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + } + r.insertForwardAcceptRules(forwardChain, intf) + + inputChain := &nftables.Chain{ + Name: chainNameInput, + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + } + r.insertInputAcceptRule(inputChain, intf) + + return r.conn.Flush() +} + +// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables). +// It dynamically finds chains at call time to handle chains that may have been created after startup. +func (r *router) acceptExternalChainsRules() error { + chains := r.findExternalChains() + if len(chains) == 0 { + return nil + } + + intf := ifname(r.wgIface.Name()) + + for _, chain := range chains { + if chain.Hooknum == nil { + log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name) + continue + } + + log.Debugf("adding accept rules to external %s chain: %s %s/%s", + hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name) + + switch *chain.Hooknum { + case *nftables.ChainHookForward: + r.insertForwardAcceptRules(chain, intf) + case *nftables.ChainHookInput: + r.insertInputAcceptRule(chain, intf) + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("flush external chain rules: %w", err) + } + + return nil +} + +func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) { iifRule := &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: chainNameForward, - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, + Table: chain.Table, + Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Cmp{ @@ -1021,30 +1117,19 @@ func (r *router) acceptFilterRulesNftables() error { Data: intf, }, } - oifRule := &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: chainNameForward, - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, + Table: chain.Table, + Chain: chain, Exprs: append(oifExprs, getEstablishedExprs(2)...), UserData: []byte(userDataAcceptForwardRuleOif), } r.conn.InsertRule(oifRule) +} +func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) { inputRule := &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: chainNameInput, - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookInput, - Priority: nftables.ChainPriorityFilter, - }, + Table: chain.Table, + Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Cmp{ @@ -1058,32 +1143,44 @@ func (r *router) acceptFilterRulesNftables() error { UserData: []byte(userDataAcceptInputRule), } r.conn.InsertRule(inputRule) - - return r.conn.Flush() } func (r *router) removeAcceptFilterRules() error { + var merr *multierror.Error + + if err := r.removeFilterTableRules(); err != nil { + merr = multierror.Append(merr, err) + } + + if err := r.removeExternalChainsRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) removeFilterTableRules() error { if r.filterTable == nil { return nil } ipt, err := iptables.New() if err != nil { - log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) - return r.removeAcceptFilterRulesNftables() + log.Debugf("iptables not available, using nftables to remove filter rules: %v", err) + return r.removeAcceptRulesFromTable(r.filterTable) } return r.removeAcceptFilterRulesIptables(ipt) } -func (r *router) removeAcceptFilterRulesNftables() error { - chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) +func (r *router) removeAcceptRulesFromTable(table *nftables.Table) error { + chains, err := r.conn.ListChainsOfTableFamily(table.Family) if err != nil { return fmt.Errorf("list chains: %v", err) } for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name { + if chain.Table.Name != table.Name { continue } @@ -1091,27 +1188,101 @@ func (r *router) removeAcceptFilterRulesNftables() error { continue } - rules, err := r.conn.GetRules(r.filterTable, chain) + if err := r.removeAcceptRulesFromChain(table, chain); err != nil { + return err + } + } + + return r.conn.Flush() +} + +func (r *router) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error { + rules, err := r.conn.GetRules(table, chain) + if err != nil { + return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err) + } + + for _, rule := range rules { + if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err) + } + } + } + return nil +} + +// removeExternalChainsRules removes our accept rules from all external chains. +// This is deterministic - it scans for chains at removal time rather than relying on saved state, +// ensuring cleanup works even after a crash or if chains changed. +func (r *router) removeExternalChainsRules() error { + chains := r.findExternalChains() + if len(chains) == 0 { + return nil + } + + for _, chain := range chains { + if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil { + log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err) + } + } + + return r.conn.Flush() +} + +// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks. +// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal). +func (r *router) findExternalChains() []*nftables.Chain { + var chains []*nftables.Chain + + families := []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyINet} + + for _, family := range families { + allChains, err := r.conn.ListChainsOfTableFamily(family) if err != nil { - return fmt.Errorf("get rules: %v", err) + log.Debugf("list chains for family %d: %v", family, err) + continue } - for _, rule := range rules { - if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || - bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) || - bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) { - if err := r.conn.DelRule(rule); err != nil { - return fmt.Errorf("delete rule: %v", err) - } + for _, chain := range allChains { + if r.isExternalChain(chain) { + chains = append(chains, chain) } } } - if err := r.conn.Flush(); err != nil { - return fmt.Errorf(flushError, err) + return chains +} + +func (r *router) isExternalChain(chain *nftables.Chain) bool { + if r.workTable != nil && chain.Table.Name == r.workTable.Name { + return false } - return nil + // Skip all iptables-managed tables in the ip family + if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) { + return false + } + + if chain.Type != nftables.ChainTypeFilter { + return false + } + + if chain.Hooknum == nil { + return false + } + + return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput +} + +func isIptablesTable(name string) bool { + switch name { + case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity: + return true + } + return false } func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error { @@ -1119,13 +1290,13 @@ func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error { for _, rule := range r.getAcceptForwardRules() { if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err)) + merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err)) } } inputRule := r.getAcceptInputRule() if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err)) + merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err)) } return nberrors.FormatErrorOrNil(merr)