diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 7ee33118b..72e6a5c68 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -100,6 +100,9 @@ type Manager interface { // // If comment argument is empty firewall manager should set // rule ID as comment for the rule + // + // Note: Callers should call Flush() after adding rules to ensure + // they are applied to the kernel and rule handles are refreshed. AddPeerFiltering( id []byte, ip net.IP, diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 9ff5b8c92..a9d066e2f 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -29,8 +29,6 @@ const ( chainNameForwardFilter = "netbird-acl-forward-filter" chainNameManglePrerouting = "netbird-mangle-prerouting" chainNameManglePostrouting = "netbird-mangle-postrouting" - - allowNetbirdInputRuleID = "allow Netbird incoming traffic" ) const flushError = "flush: %w" @@ -195,25 +193,6 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { // createDefaultAllowRules creates default allow rules for the input and output chains func (m *AclManager) createDefaultAllowRules() error { expIn := []expr.Any{ - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - // mask - &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 4, - Mask: []byte{0, 0, 0, 0}, - Xor: []byte{0, 0, 0, 0}, - }, - // net address - &expr.Cmp{ - Register: 1, - Data: []byte{0, 0, 0, 0}, - }, &expr.Verdict{ Kind: expr.VerdictAccept, }, @@ -258,7 +237,7 @@ func (m *AclManager) addIOFiltering( action firewall.Action, ipset *nftables.Set, ) (*Rule, error) { - ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset) + ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset) if r, ok := m.rules[ruleId]; ok { return &Rule{ nftRule: r.nftRule, @@ -357,11 +336,12 @@ func (m *AclManager) addIOFiltering( } if err := m.rConn.Flush(); err != nil { - return nil, fmt.Errorf(flushError, err) + return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err) } ruleStruct := &Rule{ - nftRule: nftRule, + nftRule: nftRule, + // best effort mangle rule mangleRule: m.createPreroutingRule(expressions, userData), nftSet: ipset, ruleID: ruleId, @@ -420,12 +400,19 @@ func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byt }, ) - return m.rConn.AddRule(&nftables.Rule{ + nfRule := m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, Chain: m.chainPrerouting, Exprs: preroutingExprs, UserData: userData, }) + + if err := m.rConn.Flush(); err != nil { + log.Errorf("failed to flush mangle rule %s: %v", string(userData), err) + return nil + } + + return nfRule } func (m *AclManager) createDefaultChains() (err error) { @@ -697,8 +684,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) erro return nil } -func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string { - rulesetID := ":" +func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string { + rulesetID := ":" + string(proto) + ":" if sPort != nil { rulesetID += sPort.String() } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index d864914fe..bd19f1067 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -1,11 +1,11 @@ package nftables import ( - "bytes" "context" "fmt" "net" "net/netip" + "os" "sync" "github.com/google/nftables" @@ -19,13 +19,22 @@ import ( ) const ( - // tableNameNetbird is the name of the table that is used for filtering by the Netbird client + // tableNameNetbird is the default name of the table that is used for filtering by the Netbird client tableNameNetbird = "netbird" + // envTableName is the environment variable to override the table name + envTableName = "NB_NFTABLES_TABLE" tableNameFilter = "filter" chainNameInput = "INPUT" ) +func getTableName() string { + if name := os.Getenv(envTableName); name != "" { + return name + } + return tableNameNetbird +} + // iFaceMapper defines subset methods of interface required for manager type iFaceMapper interface { Name() string @@ -50,7 +59,7 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { wgIface: wgIface, } - workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4} + workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4} var err error m.router, err = newRouter(workTable, wgIface, mtu) @@ -198,44 +207,11 @@ func (m *Manager) AllowNetbird() error { m.mutex.Lock() defer m.mutex.Unlock() - err := m.aclManager.createDefaultAllowRules() - if err != nil { - return fmt.Errorf("failed to create default allow rules: %v", err) + if err := m.aclManager.createDefaultAllowRules(); err != nil { + return fmt.Errorf("create default allow rules: %w", err) } - - chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) - if err != nil { - return fmt.Errorf("list of chains: %w", err) - } - - var chain *nftables.Chain - for _, c := range chains { - if c.Table.Name == tableNameFilter && c.Name == chainNameInput { - chain = c - break - } - } - - if chain == nil { - log.Debugf("chain INPUT not found. Skipping add allow netbird rule") - return nil - } - - rules, err := m.rConn.GetRules(chain.Table, chain) - if err != nil { - return fmt.Errorf("failed to get rules for the INPUT chain: %v", err) - } - - if rule := m.detectAllowNetbirdRule(rules); rule != nil { - log.Debugf("allow netbird rule already exists: %v", rule) - return nil - } - - m.applyAllowNetbirdRules(chain) - - err = m.rConn.Flush() - if err != nil { - return fmt.Errorf("failed to flush allow input netbird rules: %v", err) + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf("flush allow input netbird rules: %w", err) } return nil @@ -251,10 +227,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() - if err := m.resetNetbirdInputRules(); err != nil { - return fmt.Errorf("reset netbird input rules: %v", err) - } - if err := m.router.Reset(); err != nil { return fmt.Errorf("reset router: %v", err) } @@ -274,49 +246,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { return nil } -func (m *Manager) resetNetbirdInputRules() error { - chains, err := m.rConn.ListChains() - if err != nil { - return fmt.Errorf("list chains: %w", err) - } - - m.deleteNetbirdInputRules(chains) - - return nil -} - -func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) { - for _, c := range chains { - if c.Table.Name == tableNameFilter && c.Name == chainNameInput { - rules, err := m.rConn.GetRules(c.Table, c) - if err != nil { - log.Errorf("get rules for chain %q: %v", c.Name, err) - continue - } - - m.deleteMatchingRules(rules) - } - } -} - -func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) { - for _, r := range rules { - if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { - if err := m.rConn.DelRule(r); err != nil { - log.Errorf("delete rule: %v", err) - } - } - } -} - func (m *Manager) cleanupNetbirdTables() error { tables, err := m.rConn.ListTables() if err != nil { return fmt.Errorf("list tables: %w", err) } + tableName := getTableName() for _, t := range tables { - if t.Name == tableNameNetbird { + if t.Name == tableName { m.rConn.DelTable(t) } } @@ -399,55 +337,18 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) { return nil, fmt.Errorf("list of tables: %w", err) } + tableName := getTableName() for _, t := range tables { - if t.Name == tableNameNetbird { + if t.Name == tableName { m.rConn.DelTable(t) } } - table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}) err = m.rConn.Flush() return table, err } -func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { - rule := &nftables.Rule{ - Table: chain.Table, - Chain: chain, - Exprs: []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - }, - UserData: []byte(allowNetbirdInputRuleID), - } - _ = m.rConn.InsertRule(rule) -} - -func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule { - ifName := ifname(m.wgIface.Name()) - for _, rule := range existedRules { - if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput { - if len(rule.Exprs) < 4 { - if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME { - continue - } - if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) { - continue - } - return rule - } - } - } - return nil -} - func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) { rule := &nftables.Rule{ Table: table, diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 0a2c79186..6192c92aa 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -37,6 +37,7 @@ const ( userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleOif = "frwacceptoif" + userDataAcceptInputRule = "inputaccept" dnatSuffix = "_dnat" snatSuffix = "_snat" @@ -103,8 +104,8 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou func (r *router) init(workTable *nftables.Table) error { r.workTable = workTable - if err := r.removeAcceptForwardRules(); err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + if err := r.removeAcceptFilterRules(); err != nil { + log.Errorf("failed to clean up rules from filter table: %s", err) } if err := r.createContainers(); err != nil { @@ -118,15 +119,15 @@ func (r *router) init(workTable *nftables.Table) error { return nil } -// Reset cleans existing nftables default forward rules from the system +// Reset cleans existing nftables filter table rules from the system func (r *router) Reset() error { // clear without deleting the ipsets, the nf table will be deleted by the caller r.ipsetCounter.Clear() var merr *multierror.Error - if err := r.removeAcceptForwardRules(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err)) + if err := r.removeAcceptFilterRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err)) } if err := r.removeNatPreroutingRules(); err != nil { @@ -936,6 +937,7 @@ func (r *router) RemoveAllLegacyRouteRules() error { // that our traffic is not dropped by existing rules there. // The existing FORWARD rules/policies decide outbound traffic towards our interface. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. +// This method also adds INPUT chain rules to allow traffic to the local interface. func (r *router) acceptForwardRules() error { if r.filterTable == nil { log.Debugf("table 'filter' not found for forward rules, skipping accept rules") @@ -945,7 +947,7 @@ func (r *router) acceptForwardRules() error { fw := "iptables" defer func() { - log.Debugf("Used %s to add accept forward rules", fw) + log.Debugf("Used %s to add accept forward and input rules", fw) }() // Try iptables first and fallback to nftables if iptables is not available @@ -955,22 +957,30 @@ func (r *router) acceptForwardRules() error { log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) fw = "nftables" - return r.acceptForwardRulesNftables() + return r.acceptFilterRulesNftables() } - return r.acceptForwardRulesIptables(ipt) + return r.acceptFilterRulesIptables(ipt) } -func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error { +func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err)) + merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err)) } else { - log.Debugf("added iptables rule: %v", rule) + log.Debugf("added iptables forward rule: %v", rule) } } + inputRule := r.getAcceptInputRule() + if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil { + merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err)) + } else { + log.Debugf("added iptables input rule: %v", inputRule) + } + return nberrors.FormatErrorOrNil(merr) } @@ -982,10 +992,13 @@ func (r *router) getAcceptForwardRules() [][]string { } } -func (r *router) acceptForwardRulesNftables() error { +func (r *router) getAcceptInputRule() []string { + return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"} +} + +func (r *router) acceptFilterRulesNftables() error { intf := ifname(r.wgIface.Name()) - // Rule for incoming interface (iif) with counter iifRule := &nftables.Rule{ Table: r.filterTable, Chain: &nftables.Chain{ @@ -1018,11 +1031,10 @@ func (r *router) acceptForwardRulesNftables() error { }, } - // Rule for outgoing interface (oif) with counter oifRule := &nftables.Rule{ Table: r.filterTable, Chain: &nftables.Chain{ - Name: "FORWARD", + Name: chainNameForward, Table: r.filterTable, Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookForward, @@ -1031,35 +1043,60 @@ func (r *router) acceptForwardRulesNftables() error { Exprs: append(oifExprs, getEstablishedExprs(2)...), UserData: []byte(userDataAcceptForwardRuleOif), } - r.conn.InsertRule(oifRule) + inputRule := &nftables.Rule{ + Table: r.filterTable, + Chain: &nftables.Chain{ + Name: chainNameInput, + Table: r.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + UserData: []byte(userDataAcceptInputRule), + } + r.conn.InsertRule(inputRule) + return nil } -func (r *router) removeAcceptForwardRules() error { +func (r *router) removeAcceptFilterRules() error { if r.filterTable == nil { return nil } - // Try iptables first and fallback to nftables if iptables is not available ipt, err := iptables.New() if err != nil { log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) - return r.removeAcceptForwardRulesNftables() + return r.removeAcceptFilterRulesNftables() } - return r.removeAcceptForwardRulesIptables(ipt) + return r.removeAcceptFilterRulesIptables(ipt) } -func (r *router) removeAcceptForwardRulesNftables() error { +func (r *router) removeAcceptFilterRulesNftables() error { chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) if err != nil { return fmt.Errorf("list chains: %v", err) } for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { + if chain.Table.Name != r.filterTable.Name { + continue + } + + if chain.Name != chainNameForward && chain.Name != chainNameInput { continue } @@ -1070,7 +1107,8 @@ func (r *router) removeAcceptForwardRulesNftables() error { for _, rule := range rules { if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || - bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) { if err := r.conn.DelRule(rule); err != nil { return fmt.Errorf("delete rule: %v", err) } @@ -1085,14 +1123,20 @@ func (r *router) removeAcceptForwardRulesNftables() error { return nil } -func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error { +func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error { var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err)) + merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err)) } } + inputRule := r.getAcceptInputRule() + if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil { + merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err)) + } + return nberrors.FormatErrorOrNil(merr) } diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index b26836d17..58b88d9ef 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -15,6 +15,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/peer" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" @@ -134,6 +135,8 @@ func (m *Manager) Stop(ctx context.Context) error { } } + m.unregisterNetstackServices() + if err := m.dropDNSFirewall(); err != nil { mErr = multierror.Append(mErr, err) } @@ -158,21 +161,50 @@ func (m *Manager) allowDNSFirewall() error { dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "") if err != nil { - log.Errorf("failed to add allow DNS router rules, err: %v", err) - return err + return fmt.Errorf("add udp firewall rule: %w", err) } - m.fwRules = dnsRules tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "") if err != nil { - log.Errorf("failed to add allow DNS router rules, err: %v", err) - return err + return fmt.Errorf("add tcp firewall rule: %w", err) } + + if err := m.firewall.Flush(); err != nil { + return fmt.Errorf("flush: %w", err) + } + + m.fwRules = dnsRules m.tcpRules = tcpRules + m.registerNetstackServices() + return nil } +func (m *Manager) registerNetstackServices() { + if netstackNet := m.wgIface.GetNet(); netstackNet != nil { + if registrar, ok := m.firewall.(interface { + RegisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.RegisterNetstackService(nftypes.TCP, m.serverPort) + registrar.RegisterNetstackService(nftypes.UDP, m.serverPort) + log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort) + } + } +} + +func (m *Manager) unregisterNetstackServices() { + if netstackNet := m.wgIface.GetNet(); netstackNet != nil { + if registrar, ok := m.firewall.(interface { + UnregisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.UnregisterNetstackService(nftypes.TCP, m.serverPort) + registrar.UnregisterNetstackService(nftypes.UDP, m.serverPort) + log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort) + } + } +} + func (m *Manager) dropDNSFirewall() error { var mErr *multierror.Error for _, rule := range m.fwRules { diff --git a/client/internal/engine.go b/client/internal/engine.go index 0c7bd9f0a..3c7d52cb3 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -298,17 +298,12 @@ func (e *Engine) Stop() error { e.ingressGatewayMgr = nil } + e.stopDNSForwarder() + if e.routeManager != nil { e.routeManager.Stop(e.stateManager) } - if e.dnsForwardMgr != nil { - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = nil - } - if e.srWatcher != nil { e.srWatcher.Close() } @@ -1873,7 +1868,6 @@ func (e *Engine) updateDNSForwarder( func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) { e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface) - e.registerDNSServices() if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) @@ -1893,34 +1887,9 @@ func (e *Engine) stopDNSForwarder() { log.Errorf("failed to stop DNS forward: %v", err) } - e.unregisterDNSServices() e.dnsForwardMgr = nil } -func (e *Engine) registerDNSServices() { - if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { - if registrar, ok := e.firewall.(interface { - RegisterNetstackService(protocol nftypes.Protocol, port uint16) - }); ok { - registrar.RegisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort) - registrar.RegisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort) - log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort) - } - } -} - -func (e *Engine) unregisterDNSServices() { - if netstackNet := e.wgInterface.GetNet(); netstackNet != nil { - if registrar, ok := e.firewall.(interface { - UnregisterNetstackService(protocol nftypes.Protocol, port uint16) - }); ok { - registrar.UnregisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort) - registrar.UnregisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort) - log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort) - } - } -} - func (e *Engine) GetNet() (*netstack.Net, error) { e.syncMsgMux.Lock() intf := e.wgInterface