diff --git a/client/firewall/nftables/external_chain_monitor_integration_linux_test.go b/client/firewall/nftables/external_chain_monitor_integration_linux_test.go new file mode 100644 index 000000000..3c4e3f44d --- /dev/null +++ b/client/firewall/nftables/external_chain_monitor_integration_linux_test.go @@ -0,0 +1,76 @@ +//go:build linux + +package nftables + +import ( + "os" + "sync/atomic" + "testing" + "time" + + "github.com/google/nftables" + "github.com/stretchr/testify/require" +) + +// TestExternalChainMonitorRootIntegration verifies that adding a new chain +// in an external (non-netbird) filter table triggers the reconciler. +// Requires CAP_NET_ADMIN; skip otherwise. +func TestExternalChainMonitorRootIntegration(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip("root required") + } + + calls := make(chan struct{}, 8) + var count atomic.Int32 + rec := &countingReconciler{calls: calls, count: &count} + + m := newExternalChainMonitor(rec) + m.start() + t.Cleanup(m.stop) + + // Give the netlink subscription a moment to register. + time.Sleep(200 * time.Millisecond) + + conn := &nftables.Conn{} + table := conn.AddTable(&nftables.Table{ + Name: "nbmon_integration_test", + Family: nftables.TableFamilyINet, + }) + t.Cleanup(func() { + cleanup := &nftables.Conn{} + cleanup.DelTable(table) + _ = cleanup.Flush() + }) + + chain := conn.AddChain(&nftables.Chain{ + Name: "filter_INPUT", + Table: table, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + Type: nftables.ChainTypeFilter, + }) + _ = chain + require.NoError(t, conn.Flush(), "create external test chain") + + select { + case <-calls: + // success + case <-time.After(3 * time.Second): + t.Fatalf("reconcile was not invoked after creating an external chain") + } + require.GreaterOrEqual(t, count.Load(), int32(1)) +} + +type countingReconciler struct { + calls chan struct{} + count *atomic.Int32 +} + +func (c *countingReconciler) reconcileExternalChains() error { + c.count.Add(1) + select { + case c.calls <- struct{}{}: + default: + } + return nil +} diff --git a/client/firewall/nftables/external_chain_monitor_linux.go b/client/firewall/nftables/external_chain_monitor_linux.go new file mode 100644 index 000000000..2a2e04c09 --- /dev/null +++ b/client/firewall/nftables/external_chain_monitor_linux.go @@ -0,0 +1,199 @@ +package nftables + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/google/nftables" + log "github.com/sirupsen/logrus" +) + +const ( + externalMonitorReconcileDelay = 500 * time.Millisecond + externalMonitorInitInterval = 5 * time.Second + externalMonitorMaxInterval = 5 * time.Minute + externalMonitorRandomization = 0.5 +) + +// externalChainReconciler re-applies passthrough accept rules to external +// nftables chains. Implementations must be safe to call from the monitor +// goroutine; the Manager locks its mutex internally. +type externalChainReconciler interface { + reconcileExternalChains() error +} + +// externalChainMonitor watches nftables netlink events and triggers a +// reconcile when a new table or chain appears (e.g. after +// `firewall-cmd --reload`). Netlink errors trigger exponential-backoff +// reconnect. +type externalChainMonitor struct { + reconciler externalChainReconciler + + mu sync.Mutex + cancel context.CancelFunc + done chan struct{} +} + +func newExternalChainMonitor(r externalChainReconciler) *externalChainMonitor { + return &externalChainMonitor{reconciler: r} +} + +func (m *externalChainMonitor) start() { + m.mu.Lock() + defer m.mu.Unlock() + + if m.cancel != nil { + return + } + + ctx, cancel := context.WithCancel(context.Background()) + m.cancel = cancel + m.done = make(chan struct{}) + + go m.run(ctx) +} + +func (m *externalChainMonitor) stop() { + m.mu.Lock() + cancel := m.cancel + done := m.done + m.cancel = nil + m.done = nil + m.mu.Unlock() + + if cancel == nil { + return + } + cancel() + <-done +} + +func (m *externalChainMonitor) run(ctx context.Context) { + defer close(m.done) + + bo := &backoff.ExponentialBackOff{ + InitialInterval: externalMonitorInitInterval, + RandomizationFactor: externalMonitorRandomization, + Multiplier: backoff.DefaultMultiplier, + MaxInterval: externalMonitorMaxInterval, + MaxElapsedTime: 0, + Clock: backoff.SystemClock, + } + bo.Reset() + + for ctx.Err() == nil { + err := m.watch(ctx) + if ctx.Err() != nil { + return + } + + delay := bo.NextBackOff() + log.Warnf("external chain monitor: %v, reconnecting in %s", err, delay) + select { + case <-ctx.Done(): + return + case <-time.After(delay): + } + } +} + +func (m *externalChainMonitor) watch(ctx context.Context) error { + events, closeMon, err := m.subscribe() + if err != nil { + return err + } + defer closeMon() + + debounce := time.NewTimer(time.Hour) + if !debounce.Stop() { + <-debounce.C + } + defer debounce.Stop() + + pending := false + for { + select { + case <-ctx.Done(): + return nil + case <-debounce.C: + pending = false + m.reconcile() + case ev, ok := <-events: + if !ok { + return errors.New("monitor channel closed") + } + if ev.Error != nil { + return fmt.Errorf("monitor event: %w", ev.Error) + } + if !isRelevantMonitorEvent(ev) { + continue + } + resetDebounce(debounce, pending) + pending = true + } + } +} + +func (m *externalChainMonitor) subscribe() (chan *nftables.MonitorEvent, func(), error) { + conn := &nftables.Conn{} + mon := nftables.NewMonitor( + nftables.WithMonitorAction(nftables.MonitorActionNew), + nftables.WithMonitorObject(nftables.MonitorObjectChains|nftables.MonitorObjectTables), + ) + events, err := conn.AddMonitor(mon) + if err != nil { + return nil, nil, fmt.Errorf("add netlink monitor: %w", err) + } + return events, func() { _ = mon.Close() }, nil +} + +// resetDebounce reschedules a pending debounce timer without leaking a stale +// fire on its channel. pending must reflect whether the timer is armed. +func resetDebounce(t *time.Timer, pending bool) { + if pending && !t.Stop() { + select { + case <-t.C: + default: + } + } + t.Reset(externalMonitorReconcileDelay) +} + +func (m *externalChainMonitor) reconcile() { + if err := m.reconciler.reconcileExternalChains(); err != nil { + log.Warnf("reconcile external chain rules: %v", err) + } +} + +// isRelevantMonitorEvent returns true for table/chain creation events on +// families we care about. The reconciler filters to actual external filter +// chains. +func isRelevantMonitorEvent(ev *nftables.MonitorEvent) bool { + switch ev.Type { + case nftables.MonitorEventTypeNewChain: + chain, ok := ev.Data.(*nftables.Chain) + if !ok || chain == nil || chain.Table == nil { + return false + } + return isMonitoredFamily(chain.Table.Family) + case nftables.MonitorEventTypeNewTable: + table, ok := ev.Data.(*nftables.Table) + if !ok || table == nil { + return false + } + return isMonitoredFamily(table.Family) + } + return false +} + +func isMonitoredFamily(family nftables.TableFamily) bool { + switch family { + case nftables.TableFamilyIPv4, nftables.TableFamilyIPv6, nftables.TableFamilyINet: + return true + } + return false +} diff --git a/client/firewall/nftables/external_chain_monitor_linux_test.go b/client/firewall/nftables/external_chain_monitor_linux_test.go new file mode 100644 index 000000000..1a37faca2 --- /dev/null +++ b/client/firewall/nftables/external_chain_monitor_linux_test.go @@ -0,0 +1,137 @@ +package nftables + +import ( + "testing" + + "github.com/google/nftables" + "github.com/stretchr/testify/assert" +) + +func TestIsMonitoredFamily(t *testing.T) { + tests := []struct { + family nftables.TableFamily + want bool + }{ + {nftables.TableFamilyIPv4, true}, + {nftables.TableFamilyIPv6, true}, + {nftables.TableFamilyINet, true}, + {nftables.TableFamilyARP, false}, + {nftables.TableFamilyBridge, false}, + {nftables.TableFamilyNetdev, false}, + {nftables.TableFamilyUnspecified, false}, + } + for _, tc := range tests { + assert.Equal(t, tc.want, isMonitoredFamily(tc.family), "family=%d", tc.family) + } +} + +func TestIsRelevantMonitorEvent(t *testing.T) { + inetTable := &nftables.Table{Name: "firewalld", Family: nftables.TableFamilyINet} + ipTable := &nftables.Table{Name: "filter", Family: nftables.TableFamilyIPv4} + arpTable := &nftables.Table{Name: "arp", Family: nftables.TableFamilyARP} + + tests := []struct { + name string + ev *nftables.MonitorEvent + want bool + }{ + { + name: "new chain in inet firewalld", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewChain, + Data: &nftables.Chain{Name: "filter_INPUT", Table: inetTable}, + }, + want: true, + }, + { + name: "new chain in ip filter", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewChain, + Data: &nftables.Chain{Name: "INPUT", Table: ipTable}, + }, + want: true, + }, + { + name: "new chain in unwatched arp family", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewChain, + Data: &nftables.Chain{Name: "x", Table: arpTable}, + }, + want: false, + }, + { + name: "new table inet", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewTable, + Data: inetTable, + }, + want: true, + }, + { + name: "del chain (we only act on new)", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeDelChain, + Data: &nftables.Chain{Name: "filter_INPUT", Table: inetTable}, + }, + want: false, + }, + { + name: "chain with nil table", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewChain, + Data: &nftables.Chain{Name: "x"}, + }, + want: false, + }, + { + name: "nil data", + ev: &nftables.MonitorEvent{ + Type: nftables.MonitorEventTypeNewChain, + Data: (*nftables.Chain)(nil), + }, + want: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, isRelevantMonitorEvent(tc.ev)) + }) + } +} + +// fakeReconciler records reconcile invocations for debounce tests. +type fakeReconciler struct { + calls chan struct{} +} + +func (f *fakeReconciler) reconcileExternalChains() error { + f.calls <- struct{}{} + return nil +} + +func TestExternalChainMonitorStopWithoutStart(t *testing.T) { + m := newExternalChainMonitor(&fakeReconciler{calls: make(chan struct{}, 1)}) + // Must not panic or block. + m.stop() +} + +func TestExternalChainMonitorDoubleStart(t *testing.T) { + // start() twice should be a no-op; stop() cleans up once. + // We avoid exercising the netlink watch loop here because it needs root. + m := newExternalChainMonitor(&fakeReconciler{calls: make(chan struct{}, 1)}) + + // Replace run with a stub that just waits for cancel, so start() stays + // deterministic without opening a netlink socket. + origDone := make(chan struct{}) + m.done = origDone + m.cancel = func() { close(origDone) } + + // Second start should be a no-op (cancel already set). + m.start() + assert.NotNil(t, m.cancel) + + m.stop() + assert.Nil(t, m.cancel) + assert.Nil(t, m.done) +} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index e53209ecc..68b2f2b1a 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -59,6 +59,8 @@ type Manager struct { notrackOutputChain *nftables.Chain notrackPreroutingChain *nftables.Chain + + extMonitor *externalChainMonitor } // Create nftables firewall manager @@ -88,6 +90,8 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { } } + m.extMonitor = newExternalChainMonitor(m) + return m, nil } @@ -142,9 +146,34 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { m.persistState(stateManager) + // Start after initFirewall has installed the baseline external-chain + // accept rules. start() is idempotent across Init/Close/Init cycles. + m.extMonitor.start() + return nil } +// reconcileExternalChains re-applies passthrough accept rules to external +// filter chains for both IPv4 and IPv6 routers. Called by the monitor when +// tables or chains appear (e.g. after firewalld reloads). +func (m *Manager) reconcileExternalChains() error { + m.mutex.Lock() + defer m.mutex.Unlock() + + var merr *multierror.Error + if m.router != nil { + if err := m.router.acceptExternalChainsRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("v4: %w", err)) + } + } + if m.hasIPv6() { + if err := m.router6.acceptExternalChainsRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("v6: %w", err)) + } + } + return nberrors.FormatErrorOrNil(merr) +} + func (m *Manager) initFirewall() error { workTable, err := m.createWorkTable() if err != nil { @@ -409,6 +438,8 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { // Close closes the firewall manager func (m *Manager) Close(stateManager *statemanager.Manager) error { + m.extMonitor.stop() + m.mutex.Lock() defer m.mutex.Unlock() diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index dc714fb5c..530db5e82 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -1157,83 +1157,122 @@ func (r *router) acceptExternalChainsRules() error { } intf := ifname(r.wgIface.Name()) - for _, chain := range chains { - if chain.Hooknum == nil { - log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name) - continue - } - - log.Debugf("adding accept rules to external %s chain: %s %s/%s", - hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name) - - switch *chain.Hooknum { - case *nftables.ChainHookForward: - r.insertForwardAcceptRules(chain, intf) - case *nftables.ChainHookInput: - r.insertInputAcceptRule(chain, intf) - } + r.applyExternalChainAccept(chain, intf) } if err := r.conn.Flush(); err != nil { return fmt.Errorf("flush external chain rules: %w", err) } - return nil } +func (r *router) applyExternalChainAccept(chain *nftables.Chain, intf []byte) { + if chain.Hooknum == nil { + log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name) + return + } + + log.Debugf("adding accept rules to external %s chain: %s %s/%s", + hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name) + + switch *chain.Hooknum { + case *nftables.ChainHookForward: + r.insertForwardAcceptRules(chain, intf) + case *nftables.ChainHookInput: + r.insertInputAcceptRule(chain, intf) + } +} + func (r *router) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) { - iifRule := &nftables.Rule{ + existing, err := r.existingNetbirdRulesInChain(chain) + if err != nil { + log.Warnf("skip forward accept rules in %s/%s: %v", chain.Table.Name, chain.Name, err) + return + } + r.insertForwardIifRule(chain, intf, existing) + r.insertForwardOifEstablishedRule(chain, intf, existing) +} + +func (r *router) insertForwardIifRule(chain *nftables.Chain, intf []byte, existing map[string]bool) { + if existing[userDataAcceptForwardRuleIif] { + return + } + r.conn.InsertRule(&nftables.Rule{ Table: chain.Table, Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: intf, - }, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf}, &expr.Counter{}, &expr.Verdict{Kind: expr.VerdictAccept}, }, UserData: []byte(userDataAcceptForwardRuleIif), - } - r.conn.InsertRule(iifRule) + }) +} - oifExprs := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: intf, - }, +func (r *router) insertForwardOifEstablishedRule(chain *nftables.Chain, intf []byte, existing map[string]bool) { + if existing[userDataAcceptForwardRuleOif] { + return } - oifRule := &nftables.Rule{ + exprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf}, + } + r.conn.InsertRule(&nftables.Rule{ Table: chain.Table, Chain: chain, - Exprs: append(oifExprs, getEstablishedExprs(2)...), + Exprs: append(exprs, getEstablishedExprs(2)...), UserData: []byte(userDataAcceptForwardRuleOif), - } - r.conn.InsertRule(oifRule) + }) } func (r *router) insertInputAcceptRule(chain *nftables.Chain, intf []byte) { - inputRule := &nftables.Rule{ + existing, err := r.existingNetbirdRulesInChain(chain) + if err != nil { + log.Warnf("skip input accept rule in %s/%s: %v", chain.Table.Name, chain.Name, err) + return + } + if existing[userDataAcceptInputRule] { + return + } + r.conn.InsertRule(&nftables.Rule{ Table: chain.Table, Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: intf, - }, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf}, &expr.Counter{}, &expr.Verdict{Kind: expr.VerdictAccept}, }, UserData: []byte(userDataAcceptInputRule), + }) +} + +// existingNetbirdRulesInChain returns the set of netbird-owned UserData tags present in a chain; callers must bail on error since InsertRule is additive. +func (r *router) existingNetbirdRulesInChain(chain *nftables.Chain) (map[string]bool, error) { + rules, err := r.conn.GetRules(chain.Table, chain) + if err != nil { + return nil, fmt.Errorf("list rules: %w", err) } - r.conn.InsertRule(inputRule) + present := map[string]bool{} + for _, rule := range rules { + if !isNetbirdAcceptRuleTag(rule.UserData) { + continue + } + present[string(rule.UserData)] = true + } + return present, nil +} + +func isNetbirdAcceptRuleTag(userData []byte) bool { + switch string(userData) { + case userDataAcceptForwardRuleIif, + userDataAcceptForwardRuleOif, + userDataAcceptInputRule: + return true + } + return false } func (r *router) removeAcceptFilterRules() error {