diff --git a/client/internal/routemanager/firewall_linux.go b/client/internal/routemanager/firewall_linux.go index 959724ed3..8b27c8967 100644 --- a/client/internal/routemanager/firewall_linux.go +++ b/client/internal/routemanager/firewall_linux.go @@ -6,8 +6,6 @@ import ( "context" "fmt" - "github.com/coreos/go-iptables/iptables" - "github.com/google/nftables" log "github.com/sirupsen/logrus" ) @@ -30,46 +28,13 @@ func genKey(format string, input string) string { // NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager func NewFirewall(parentCTX context.Context) firewallManager { - ctx, cancel := context.WithCancel(parentCTX) - - if isIptablesSupported() { - log.Debugf("iptables is supported") - ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - if !isIptablesClientAvailable(ipv4Client) { - log.Infof("iptables is missing for ipv4") - ipv4Client = nil - } - ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6) - if !isIptablesClientAvailable(ipv6Client) { - log.Infof("iptables is missing for ipv6") - ipv6Client = nil - } - - return &iptablesManager{ - ctx: ctx, - stop: cancel, - ipv4Client: ipv4Client, - ipv6Client: ipv6Client, - rules: make(map[string]map[string][]string), - } + manager, err := newNFTablesManager(parentCTX) + if err == nil { + log.Debugf("nftables firewall manager will be used") + return manager } - - log.Debugf("iptables is not supported, using nftables") - - manager := &nftablesManager{ - ctx: ctx, - stop: cancel, - conn: &nftables.Conn{}, - chains: make(map[string]map[string]*nftables.Chain), - rules: make(map[string]*nftables.Rule), - } - - return manager -} - -func isIptablesClientAvailable(client *iptables.IPTables) bool { - _, err := client.ListChains("filter") - return err == nil + log.Debugf("fallback to iptables firewall manager: %s", err) + return newIptablesManager(parentCTX) } func getInPair(pair routerPair) routerPair { diff --git a/client/internal/routemanager/iptables_linux.go b/client/internal/routemanager/iptables_linux.go index be469b82a..b058278f3 100644 --- a/client/internal/routemanager/iptables_linux.go +++ b/client/internal/routemanager/iptables_linux.go @@ -49,6 +49,28 @@ type iptablesManager struct { mux sync.Mutex } +func newIptablesManager(parentCtx context.Context) *iptablesManager { + ctx, cancel := context.WithCancel(parentCtx) + ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) + if !isIptablesClientAvailable(ipv4Client) { + log.Infof("iptables is missing for ipv4") + ipv4Client = nil + } + ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6) + if !isIptablesClientAvailable(ipv6Client) { + log.Infof("iptables is missing for ipv6") + ipv6Client = nil + } + + return &iptablesManager{ + ctx: ctx, + stop: cancel, + ipv4Client: ipv4Client, + ipv6Client: ipv6Client, + rules: make(map[string]map[string][]string), + } +} + // CleanRoutingRules cleans existing iptables resources that we created by the agent func (i *iptablesManager) CleanRoutingRules() { i.mux.Lock() @@ -453,3 +475,8 @@ func getIptablesRuleType(table string) string { } return ruleType } + +func isIptablesClientAvailable(client *iptables.IPTables) bool { + _, err := client.ListChains("filter") + return err == nil +} diff --git a/client/internal/routemanager/iptables_linux_test.go b/client/internal/routemanager/iptables_linux_test.go index a8db05e8a..c26355e56 100644 --- a/client/internal/routemanager/iptables_linux_test.go +++ b/client/internal/routemanager/iptables_linux_test.go @@ -16,17 +16,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { t.SkipNow() } - ctx, cancel := context.WithCancel(context.TODO()) - ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6) - - manager := &iptablesManager{ - ctx: ctx, - stop: cancel, - ipv4Client: ipv4Client, - ipv6Client: ipv6Client, - rules: make(map[string]map[string][]string), - } + manager := newIptablesManager(context.TODO()) defer manager.CleanRoutingRules() @@ -37,21 +27,21 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4") - exists, err := ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...) + exists, err := manager.ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...) require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain) require.True(t, exists, "forwarding rule should exist") - exists, err = ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...) + exists, err = manager.ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...) require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain) require.True(t, exists, "postrouting rule should exist") require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6") - exists, err = ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...) + exists, err = manager.ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...) require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain) require.True(t, exists, "forwarding rule should exist") - exists, err = ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...) + exists, err = manager.ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...) require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain) require.True(t, exists, "postrouting rule should exist") @@ -64,13 +54,13 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { forward4RuleKey := genKey(forwardingFormat, pair.ID) forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination) - err = ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...) + err = manager.ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...) require.NoError(t, err, "inserting rule should not return error") nat4RuleKey := genKey(natFormat, pair.ID) nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination) - err = ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...) + err = manager.ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...) require.NoError(t, err, "inserting rule should not return error") pair = routerPair{ @@ -83,13 +73,13 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { forward6RuleKey := genKey(forwardingFormat, pair.ID) forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination) - err = ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...) + err = manager.ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...) require.NoError(t, err, "inserting rule should not return error") nat6RuleKey := genKey(natFormat, pair.ID) nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination) - err = ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...) + err = manager.ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...) require.NoError(t, err, "inserting rule should not return error") delete(manager.rules, ipv4) diff --git a/client/internal/routemanager/nftables_linux.go b/client/internal/routemanager/nftables_linux.go index 4f4f82224..dde8c143c 100644 --- a/client/internal/routemanager/nftables_linux.go +++ b/client/internal/routemanager/nftables_linux.go @@ -81,6 +81,25 @@ type nftablesManager struct { mux sync.Mutex } +func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) { + ctx, cancel := context.WithCancel(parentCtx) + + mgr := &nftablesManager{ + ctx: ctx, + stop: cancel, + conn: &nftables.Conn{}, + chains: make(map[string]map[string]*nftables.Chain), + rules: make(map[string]*nftables.Rule), + } + + err := mgr.isSupported() + if err != nil { + return nil, err + } + + return mgr, nil +} + // CleanRoutingRules cleans existing nftables rules from the system func (n *nftablesManager) CleanRoutingRules() { n.mux.Lock() @@ -386,6 +405,14 @@ func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) erro return nil } +func (n *nftablesManager) isSupported() error { + _, err := n.conn.ListChains() + if err != nil { + return fmt.Errorf("nftables is not supported: %s", err) + } + return nil +} + // getPayloadDirectives get expression directives based on ip version and direction func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) { switch { diff --git a/client/internal/routemanager/nftables_linux_test.go b/client/internal/routemanager/nftables_linux_test.go index 7ff8dd125..01fc38885 100644 --- a/client/internal/routemanager/nftables_linux_test.go +++ b/client/internal/routemanager/nftables_linux_test.go @@ -14,21 +14,16 @@ import ( func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - - manager := &nftablesManager{ - ctx: ctx, - stop: cancel, - conn: &nftables.Conn{}, - chains: make(map[string]map[string]*nftables.Chain), - rules: make(map[string]*nftables.Rule), + manager, err := newNFTablesManager(context.TODO()) + if err != nil { + t.Fatalf("failed to create nftables manager: %s", err) } nftablesTestingClient := &nftables.Conn{} defer manager.CleanRoutingRules() - err := manager.RestoreOrCreateContainers() + err = manager.RestoreOrCreateContainers() require.NoError(t, err, "shouldn't return error") require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6") @@ -134,21 +129,16 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) { for _, testCase := range insertRuleTestCases { t.Run(testCase.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - - manager := &nftablesManager{ - ctx: ctx, - stop: cancel, - conn: &nftables.Conn{}, - chains: make(map[string]map[string]*nftables.Chain), - rules: make(map[string]*nftables.Rule), + manager, err := newNFTablesManager(context.TODO()) + if err != nil { + t.Fatalf("failed to create nftables manager: %s", err) } nftablesTestingClient := &nftables.Conn{} defer manager.CleanRoutingRules() - err := manager.RestoreOrCreateContainers() + err = manager.RestoreOrCreateContainers() require.NoError(t, err, "shouldn't return error") err = manager.InsertRoutingRules(testCase.inputPair) @@ -239,21 +229,16 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { for _, testCase := range removeRuleTestCases { t.Run(testCase.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.TODO()) - - manager := &nftablesManager{ - ctx: ctx, - stop: cancel, - conn: &nftables.Conn{}, - chains: make(map[string]map[string]*nftables.Chain), - rules: make(map[string]*nftables.Rule), + manager, err := newNFTablesManager(context.TODO()) + if err != nil { + t.Fatalf("failed to create nftables manager: %s", err) } nftablesTestingClient := &nftables.Conn{} defer manager.CleanRoutingRules() - err := manager.RestoreOrCreateContainers() + err = manager.RestoreOrCreateContainers() require.NoError(t, err, "shouldn't return error") table := manager.tableIPv4