diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 40e1077be..e8f09a106 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -74,12 +74,12 @@ func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error { return nil } - err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair) + err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair) if err != nil { return err } - err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair)) + err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair)) if err != nil { return err } @@ -101,6 +101,7 @@ func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, } delete(i.rules, ruleKey) } + err = i.iptablesClient.Insert(table, chain, 1, rule...) if err != nil { return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) @@ -317,6 +318,13 @@ func (i *routerManager) createChain(table, newChain string) error { return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err) } + // Add the loopback return rule to the NAT chain + loopbackRule := []string{"-o", "lo", "-j", "RETURN"} + err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...) + if err != nil { + return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err) + } + err = i.iptablesClient.Append(table, newChain, "-j", "RETURN") if err != nil { return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err) @@ -326,6 +334,30 @@ func (i *routerManager) createChain(table, newChain string) error { return nil } +// addNATRule appends an iptables rule pair to the nat chain +func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(keyFormat, pair.ID) + rule := genRuleSpec(jump, pair.Source, pair.Destination) + existingRule, found := i.rules[ruleKey] + if found { + err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...) + if err != nil { + return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err) + } + delete(i.rules, ruleKey) + } + + // inserting after loopback ignore rule + err := i.iptablesClient.Insert(table, chain, 2, rule...) + if err != nil { + return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err) + } + + i.rules[ruleKey] = rule + + return nil +} + // genRuleSpec generates rule specification func genRuleSpec(jump, source, destination string) []string { return []string{"-s", source, "-d", destination, "-j", jump} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 8395fc270..a376c98c3 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -95,7 +95,7 @@ func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.InsertRoutingRules(pair) + return m.router.AddRoutingRules(pair) } func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { diff --git a/client/firewall/nftables/route_linux.go b/client/firewall/nftables/route_linux.go index 381136e50..71d5ac88e 100644 --- a/client/firewall/nftables/route_linux.go +++ b/client/firewall/nftables/route_linux.go @@ -22,6 +22,8 @@ const ( userDataAcceptForwardRuleSrc = "frwacceptsrc" userDataAcceptForwardRuleDst = "frwacceptdst" + + loopbackInterface = "lo\x00" ) // some presets for building nftable rules @@ -126,6 +128,22 @@ func (r *router) createContainers() error { Type: nftables.ChainTypeNAT, }) + // Add RETURN rule for loopback interface + loRule := &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingNat], + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte(loopbackInterface), + }, + &expr.Verdict{Kind: expr.VerdictReturn}, + }, + } + r.conn.InsertRule(loRule) + err := r.refreshRulesMap() if err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) @@ -138,28 +156,28 @@ func (r *router) createContainers() error { return nil } -// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain -func (r *router) InsertRoutingRules(pair manager.RouterPair) error { +// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain +func (r *router) AddRoutingRules(pair manager.RouterPair) error { err := r.refreshRulesMap() if err != nil { return err } - err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false) + err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false) if err != nil { return err } - err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false) + err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false) if err != nil { return err } if pair.Masquerade { - err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true) + err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true) if err != nil { return err } - err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true) + err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true) if err != nil { return err } @@ -177,8 +195,8 @@ func (r *router) InsertRoutingRules(pair manager.RouterPair) error { return nil } -// insertRoutingRule inserts a nftable rule to the conn client flush queue -func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error { +// addRoutingRule inserts a nftable rule to the conn client flush queue +func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error { sourceExp := generateCIDRMatcherExpressions(true, pair.Source) destExp := generateCIDRMatcherExpressions(false, pair.Destination) @@ -199,7 +217,7 @@ func (r *router) insertRoutingRule(format, chainName string, pair manager.Router } } - r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{ + r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ Table: r.workTable, Chain: r.chains[chainName], Exprs: expression, diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index aa1224a5a..913fbd5d2 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -47,7 +47,7 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) { require.NoError(t, err, "shouldn't return error") - err = manager.InsertRoutingRules(testCase.InputPair) + err = manager.AddRoutingRules(testCase.InputPair) defer func() { _ = manager.RemoveRoutingRules(testCase.InputPair) }()