diff --git a/client/firewall/iptables/manager.go b/client/firewall/iptables/manager.go index 014eadb6c..72cf6d1b2 100644 --- a/client/firewall/iptables/manager.go +++ b/client/firewall/iptables/manager.go @@ -12,12 +12,13 @@ import ( ) const ( - chainFilterName = "NETBIRD-ACL" + // ChainFilterName is the name of the chain that is used for filtering by the Netbird client + ChainFilterName = "NETBIRD-ACL" ) // Manager of iptables firewall type Manager struct { - mutex *sync.Mutex + mutex sync.Mutex ruleCnt int ipv4Client *iptables.IPTables @@ -58,6 +59,16 @@ func (m *Manager) AddFiltering( ) (fw.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() + client := m.client(ip) + ok, err := client.ChainExists("filter", ChainFilterName) + if err != nil { + return nil, fmt.Errorf("failed to check if chain exists: %s", err) + } + if !ok { + if err := client.NewChain("filter", ChainFilterName); err != nil { + return nil, fmt.Errorf("failed to create chain: %s", err) + } + } if port == nil || port.Values == nil || (port.IsRange && len(port.Values) != 2) { return nil, fmt.Errorf("invalid port definition") } @@ -65,8 +76,8 @@ func (m *Manager) AddFiltering( if port.IsRange { pv += ":" + strconv.Itoa(port.Values[1]) } - specs := m.filterRuleSpecs("filter", "INPUT", ip, pv, direction, action, comment) - if err := m.client(ip).AppendUnique("filter", "INPUT", specs...); err != nil { + specs := m.filterRuleSpecs("filter", ChainFilterName, ip, pv, direction, action, comment) + if err := client.AppendUnique("filter", ChainFilterName, specs...); err != nil { return nil, err } m.ruleCnt++ @@ -75,6 +86,8 @@ func (m *Manager) AddFiltering( // DeleteRule deletes a rule from the firewall func (m *Manager) DeleteRule(rule fw.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() r, ok := rule.(*Rule) if !ok { return fmt.Errorf("invalid rule type") @@ -83,22 +96,37 @@ func (m *Manager) DeleteRule(rule fw.Rule) error { if r.v6 { client = m.ipv6Client } - client.Delete("filter", chainFilterName, r.specs...) + client.Delete("filter", ChainFilterName, r.specs...) return nil } // Reset firewall to the default state func (m *Manager) Reset() error { - // clear chains from rules, if they doesn't exists create them - if err := m.ipv4Client.ClearChain("filter", chainFilterName); err != nil { - return err + m.mutex.Lock() + defer m.mutex.Unlock() + if err := m.reset(m.ipv4Client, "filter", ChainFilterName); err != nil { + return fmt.Errorf("clean ipv4 firewall ACL chain: %w", err) } - if err := m.ipv6Client.ClearChain("filter", chainFilterName); err != nil { - return err + if err := m.reset(m.ipv6Client, "filter", ChainFilterName); err != nil { + return fmt.Errorf("clean ipv6 firewall ACL chain: %w", err) } return nil } +func (m *Manager) reset(client *iptables.IPTables, table, chain string) error { + ok, err := client.ChainExists(table, chain) + if err != nil { + return fmt.Errorf("failed to check if chain exists: %w", err) + } + if !ok { + return nil + } + if err := client.ClearChain(table, ChainFilterName); err != nil { + return fmt.Errorf("failed to clear chain: %w", err) + } + return client.DeleteChain(table, ChainFilterName) +} + // filterRuleSpecs returns the specs of a filtering rule and its id // // id builded by hashing the table, chain and specs together @@ -111,7 +139,7 @@ func (m *Manager) filterRuleSpecs( } specs = append(specs, "-p", "tcp", "--dport", port) specs = append(specs, "-j", m.action(action)) - return append(specs, "-m", comment) + return append(specs, "-m", "comment", "--comment", comment) } // client returns corresponding iptables client for the given ip diff --git a/client/firewall/iptables/manager_test.go b/client/firewall/iptables/manager_test.go new file mode 100644 index 000000000..ad87c8b3a --- /dev/null +++ b/client/firewall/iptables/manager_test.go @@ -0,0 +1,110 @@ +package iptables + +import ( + "net" + "runtime" + "testing" + + "github.com/coreos/go-iptables/iptables" + fw "github.com/netbirdio/netbird/client/firewall" +) + +func TestNewManager(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("iptables is only supported on linux") + } + + ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + if err != nil { + t.Fatal(err) + } + + manager, err := Create() + if err != nil { + t.Fatal(err) + } + + var rule1 fw.Rule + t.Run("add first rule", func(t *testing.T) { + ip := net.ParseIP("10.20.0.2") + port := &fw.Port{Proto: fw.PortProtocolTCP, Values: []int{8080}} + rule1, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTP traffic") + if err != nil { + t.Errorf("failed to add rule: %v", err) + } + + checkRuleSpecs(t, ipv4Client, true, rule1.(*Rule).specs...) + }) + + var rule2 fw.Rule + t.Run("add second rule", func(t *testing.T) { + ip := net.ParseIP("10.20.0.3") + port := &fw.Port{ + Proto: fw.PortProtocolTCP, + Values: []int{8043: 8046}, + } + rule2, err = manager.AddFiltering( + ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTPS traffic from ports range") + if err != nil { + t.Errorf("failed to add rule: %v", err) + } + + checkRuleSpecs(t, ipv4Client, true, rule2.(*Rule).specs...) + }) + + t.Run("delete first rule", func(t *testing.T) { + if err := manager.DeleteRule(rule1); err != nil { + t.Errorf("failed to delete rule: %v", err) + } + + checkRuleSpecs(t, ipv4Client, false, rule1.(*Rule).specs...) + }) + + t.Run("delete second rule", func(t *testing.T) { + if err := manager.DeleteRule(rule2); err != nil { + t.Errorf("failed to delete rule: %v", err) + } + + checkRuleSpecs(t, ipv4Client, false, rule2.(*Rule).specs...) + }) + + t.Run("reset check", func(t *testing.T) { + // add second rule + ip := net.ParseIP("10.20.0.3") + port := &fw.Port{Proto: fw.PortProtocolUDP, Values: []int{5353}} + _, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept Fake DNS traffic") + if err != nil { + t.Errorf("failed to add rule: %v", err) + } + + if err := manager.Reset(); err != nil { + t.Errorf("failed to reset: %v", err) + } + + ok, err := ipv4Client.ChainExists("filter", ChainFilterName) + if err != nil { + t.Errorf("failed to drop chain: %v", err) + } + + if ok { + t.Errorf("chain '%v' still exists after Reset", ChainFilterName) + } + }) +} + +func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, mustExists bool, rulespec ...string) { + exists, err := ipv4Client.Exists("filter", ChainFilterName, rulespec...) + if err != nil { + t.Errorf("failed to check rule: %v", err) + return + } + + if !exists && mustExists { + t.Errorf("rule '%v' does not exist", rulespec) + return + } + if exists && !mustExists { + t.Errorf("rule '%v' exist", rulespec) + return + } +}