Add default firewall rule to allow netbird traffic (#1056)

Add a default firewall rule to allow netbird traffic to be handled 
by the access control managers.

Userspace manager behavior:
- When running on Windows, a default rule is add on Windows firewall
- For Linux, we are using one of the Kernel managers to add a single rule
- This PR doesn't handle macOS

Kernel manager behavior:
- For NFtables, if there is a filter table, an INPUT rule is added
- Iptables follows the previous flow if running on kernel mode. If running 
on userspace mode, it adds a single rule for INPUT and OUTPUT chains

A new checkerFW package has been introduced to consolidate checks across
route and access control managers.
It supports a new environment variable to skip nftables and allow iptables tests
This commit is contained in:
Givi Khojanashvili
2023-09-05 23:07:32 +04:00
committed by GitHub
parent e4bc76c4de
commit 246abda46d
24 changed files with 568 additions and 153 deletions

View File

@@ -40,6 +40,9 @@ const (
// It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality
type Manager interface {
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set

View File

@@ -44,6 +44,7 @@ type Manager struct {
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
type ruleset struct {
@@ -52,7 +53,7 @@ type ruleset struct {
}
// Create iptables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) {
func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
m := &Manager{
wgIface: wgIface,
inputDefaultRuleSpecs: []string{
@@ -62,26 +63,26 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
rulesets: make(map[string]ruleset),
}
if err := ipset.Init(); err != nil {
err := ipset.Init()
if err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
// init clients for booth ipv4 and ipv6
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
}
if isIptablesClientAvailable(ipv4Client) {
m.ipv4Client = ipv4Client
if ipv6Supported {
m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err)
}
}
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
} else {
if isIptablesClientAvailable(ipv6Client) {
m.ipv6Client = ipv6Client
}
if m.ipv4Client == nil && m.ipv6Client == nil {
return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it")
}
if err := m.Reset(); err != nil {
@@ -90,11 +91,6 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
return m, nil
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}
// AddFiltering rule to the firewall
//
// If comment is empty rule ID is used as comment
@@ -276,6 +272,38 @@ func (m *Manager) Reset() error {
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if m.wgIface.IsUserspaceBind() {
_, err := m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"allow netbird interface traffic",
)
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
_, err = m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionOUT,
fw.ActionAccept,
"",
"allow netbird interface traffic",
)
return err
}
return nil
}
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
@@ -406,7 +434,7 @@ func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
}
if err := client.AppendUnique("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil {
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
}

View File

@@ -33,6 +33,8 @@ func (i *iFaceMock) Address() iface.WGAddress {
panic("AddressFunc is not set")
}
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
@@ -53,7 +55,7 @@ func TestIptablesManager(t *testing.T) {
}
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(mock, true)
require.NoError(t, err)
time.Sleep(time.Second)
@@ -141,7 +143,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(mock, true)
require.NoError(t, err)
time.Sleep(time.Second)
@@ -229,7 +231,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(mock, true)
require.NoError(t, err)
time.Sleep(time.Second)

View File

@@ -29,6 +29,8 @@ const (
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
FilterOutputChainName = "netbird-acl-output-filter"
AllowNetbirdInputRuleID = "allow Netbird incoming traffic"
)
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
@@ -379,7 +381,7 @@ func (m *Manager) chain(
if c != nil {
return c, nil
}
return m.createChainIfNotExists(tf, name, hook, priority, cType)
return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType)
}
if ip.To4() != nil {
@@ -399,13 +401,20 @@ func (m *Manager) chain(
}
// table returns the table for the given family of the IP address
func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
func (m *Manager) table(
family nftables.TableFamily, tableName string,
) (*nftables.Table, error) {
// we cache access to Netbird ACL table only
if tableName != FilterTableName {
return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
}
if family == nftables.TableFamilyIPv4 {
if m.tableIPv4 != nil {
return m.tableIPv4, nil
}
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4)
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
if err != nil {
return nil, err
}
@@ -417,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
return m.tableIPv6, nil
}
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6)
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName)
if err != nil {
return nil, err
}
@@ -425,19 +434,21 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
return m.tableIPv6, nil
}
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
func (m *Manager) createTableIfNotExists(
family nftables.TableFamily, tableName string,
) (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(family)
if err != nil {
return nil, fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
if t.Name == FilterTableName {
if t.Name == tableName {
return t, nil
}
}
table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4})
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
if err := m.rConn.Flush(); err != nil {
return nil, err
}
@@ -446,12 +457,13 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables
func (m *Manager) createChainIfNotExists(
family nftables.TableFamily,
tableName string,
name string,
hooknum nftables.ChainHook,
priority nftables.ChainPriority,
chainType nftables.ChainType,
) (*nftables.Chain, error) {
table, err := m.table(family)
table, err := m.table(family, tableName)
if err != nil {
return nil, err
}
@@ -638,6 +650,22 @@ func (m *Manager) Reset() error {
return fmt.Errorf("list of chains: %w", err)
}
for _, c := range chains {
// delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
m.rConn.DelChain(c)
}
@@ -702,6 +730,53 @@ func (m *Manager) Flush() error {
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()
tf := nftables.TableFamilyIPv4
if m.wgIface.Address().IP.To4() == nil {
tf = nftables.TableFamilyIPv6
}
chains, err := m.rConn.ListChainsOfTableFamily(tf)
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == "filter" && c.Name == "INPUT" {
chain = c
break
}
}
if chain == nil {
log.Debugf("chain INPUT not found. Skiping add allow netbird rule")
return nil
}
rules, err := m.rConn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
}
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
log.Debugf("allow netbird rule already exists: %v", rule)
return nil
}
m.applyAllowNetbirdRules(chain)
err = m.rConn.Flush()
if err != nil {
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
}
return nil
}
func (m *Manager) flushWithBackoff() (err error) {
backoff := 4
backoffTime := 1000 * time.Millisecond
@@ -745,6 +820,44 @@ func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chai
return nil
}
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
rule := &nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: []byte(AllowNetbirdInputRuleID),
}
_ = m.rConn.InsertRule(rule)
}
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules {
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue
}
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
continue
}
return rule
}
}
}
return nil
}
func encodePort(port fw.Port) []byte {
bs := make([]byte, 2)
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))

View File

@@ -0,0 +1,19 @@
//go:build !windows && !linux
package uspfilter
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return nil
}

View File

@@ -0,0 +1,21 @@
package uspfilter
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return nil
}
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.resetHook != nil {
return m.resetHook()
}
return nil
}

View File

@@ -0,0 +1,91 @@
package uspfilter
import (
"errors"
"fmt"
"os/exec"
"strings"
"syscall"
)
type action string
const (
addRule action = "add"
deleteRule action = "delete"
firewallRuleName = "Netbird"
noRulesMatchCriteria = "No rules match the specified criteria"
)
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
return fmt.Errorf("couldn't remove windows firewall: %w", err)
}
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return manageFirewallRule(firewallRuleName,
addRule,
"dir=in",
"enable=yes",
"action=allow",
"profile=any",
"localip="+m.wgIface.Address().IP.String(),
)
}
func manageFirewallRule(ruleName string, action action, args ...string) error {
active, err := isFirewallRuleActive(ruleName)
if err != nil {
return err
}
if (action == addRule && !active) || (action == deleteRule && active) {
baseArgs := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
args := append(baseArgs, args...)
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
return cmd.Run()
}
return nil
}
func isFirewallRuleActive(ruleName string) (bool, error) {
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
output, err := cmd.Output()
if err != nil {
var exitError *exec.ExitError
if errors.As(err, &exitError) {
// if the firewall rule is not active, we expect last exit code to be 1
exitStatus := exitError.Sys().(syscall.WaitStatus).ExitStatus()
if exitStatus == 1 {
if strings.Contains(string(output), noRulesMatchCriteria) {
return false, nil
}
}
}
return false, err
}
if strings.Contains(string(output), noRulesMatchCriteria) {
return false, nil
}
return true, nil
}

View File

@@ -19,6 +19,7 @@ const layerTypeAll = 0
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(iface.PacketFilter) error
Address() iface.WGAddress
}
// RuleSet is a set of rules grouped by a string key
@@ -30,6 +31,8 @@ type Manager struct {
incomingRules map[string]RuleSet
wgNetwork *net.IPNet
decoders sync.Pool
wgIface IFaceMapper
resetHook func() error
mutex sync.RWMutex
}
@@ -65,6 +68,7 @@ func Create(iface IFaceMapper) (*Manager, error) {
},
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
wgIface: iface,
}
if err := iface.SetFilter(m); err != nil {
@@ -171,17 +175,6 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
return nil
}
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
return nil
}
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
@@ -375,3 +368,8 @@ func (m *Manager) RemovePacketHook(hookID string) error {
}
return fmt.Errorf("hook with given id not found")
}
// SetResetHook which will be executed in the end of Reset method
func (m *Manager) SetResetHook(hook func() error) {
m.resetHook = hook
}

View File

@@ -16,6 +16,7 @@ import (
type IFaceMock struct {
SetFilterFunc func(iface.PacketFilter) error
AddressFunc func() iface.WGAddress
}
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
@@ -25,6 +26,13 @@ func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
return i.SetFilterFunc(iface)
}
func (i *IFaceMock) Address() iface.WGAddress {
if i.AddressFunc == nil {
return iface.WGAddress{}
}
return i.AddressFunc()
}
func TestManagerCreate(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil },