diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 6192c92aa..e4debc179 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -91,11 +91,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*rou var err error r.filterTable, err = r.loadFilterTable() if err != nil { - if errors.Is(err, errFilterTableNotFound) { - log.Warnf("table 'filter' not found for forward rules") - } else { - return nil, fmt.Errorf("load filter table: %w", err) - } + log.Warnf("failed to load filter table, skipping accept rules: %v", err) } return r, nil @@ -175,7 +171,7 @@ func (r *router) removeNatPreroutingRules() error { func (r *router) loadFilterTable() (*nftables.Table, error) { tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { - return nil, fmt.Errorf("unable to list tables: %v", err) + return nil, fmt.Errorf("list tables: %w", err) } for _, table := range tables { @@ -193,8 +189,6 @@ func (r *router) createContainers() error { Table: r.workTable, }) - insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) - prio := *nftables.ChainPriorityNATSource - 1 r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingNat, @@ -236,9 +230,12 @@ func (r *router) createContainers() error { Type: nftables.ChainTypeFilter, }) - // Add the single NAT rule that matches on mark - if err := r.addPostroutingRules(); err != nil { - return fmt.Errorf("add single nat rule: %v", err) + insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) + + r.addPostroutingRules() + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("initialize tables: %v", err) } if err := r.addMSSClampingRules(); err != nil { @@ -250,11 +247,7 @@ func (r *router) createContainers() error { } if err := r.refreshRulesMap(); err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) - } - - if err := r.conn.Flush(); err != nil { - return fmt.Errorf("initialize tables: %v", err) + log.Errorf("failed to refresh rules: %s", err) } return nil @@ -695,7 +688,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { } // addPostroutingRules adds the masquerade rules -func (r *router) addPostroutingRules() error { +func (r *router) addPostroutingRules() { // First masquerade rule for traffic coming in from WireGuard interface exprs := []expr.Any{ // Match on the first fwmark @@ -761,8 +754,6 @@ func (r *router) addPostroutingRules() error { Chain: r.chains[chainNameRoutingNat], Exprs: exprs2, }) - - return nil } // addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. @@ -839,7 +830,7 @@ func (r *router) addMSSClampingRules() error { Exprs: exprsOut, }) - return nil + return r.conn.Flush() } // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls @@ -1068,7 +1059,7 @@ func (r *router) acceptFilterRulesNftables() error { } r.conn.InsertRule(inputRule) - return nil + return r.conn.Flush() } func (r *router) removeAcceptFilterRules() error { @@ -1196,7 +1187,7 @@ func (r *router) refreshRulesMap() error { for _, chain := range r.chains { rules, err := r.conn.GetRules(chain.Table, chain) if err != nil { - return fmt.Errorf(" unable to list rules: %v", err) + return fmt.Errorf("list rules: %w", err) } for _, rule := range rules { if len(rule.UserData) > 0 {