mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Feat linux firewall support (#805)
Update the client's engine to apply firewall rules received from the manager (results of ACL policy).
This commit is contained in:
committed by
GitHub
parent
2eb9a97fee
commit
ba7a39a4fc
@@ -13,22 +13,24 @@ type Rule interface {
|
||||
GetRuleID() string
|
||||
}
|
||||
|
||||
// Direction is the direction of the traffic
|
||||
type Direction int
|
||||
// RuleDirection is the traffic direction which a rule is applied
|
||||
type RuleDirection int
|
||||
|
||||
const (
|
||||
// DirectionSrc is the direction of the traffic from the source
|
||||
DirectionSrc Direction = iota
|
||||
// DirectionDst is the direction of the traffic from the destination
|
||||
DirectionDst
|
||||
// RuleDirectionIN applies to filters that handlers incoming traffic
|
||||
RuleDirectionIN RuleDirection = iota
|
||||
// RuleDirectionOUT applies to filters that handlers outgoing traffic
|
||||
RuleDirectionOUT
|
||||
)
|
||||
|
||||
// Action is the action to be taken on a rule
|
||||
type Action int
|
||||
|
||||
const (
|
||||
// ActionUnknown is a unknown action
|
||||
ActionUnknown Action = iota
|
||||
// ActionAccept is the action to accept a packet
|
||||
ActionAccept Action = iota
|
||||
ActionAccept
|
||||
// ActionDrop is the action to drop a packet
|
||||
ActionDrop
|
||||
)
|
||||
@@ -39,10 +41,15 @@ const (
|
||||
// Netbird client for ACL and routing functionality
|
||||
type Manager interface {
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
AddFiltering(
|
||||
ip net.IP,
|
||||
port *Port,
|
||||
direction Direction,
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
dPort *Port,
|
||||
direction RuleDirection,
|
||||
action Action,
|
||||
comment string,
|
||||
) (Rule, error)
|
||||
|
||||
@@ -8,26 +8,43 @@ import (
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
)
|
||||
|
||||
const (
|
||||
// ChainFilterName is the name of the chain that is used for filtering by the Netbird client
|
||||
ChainFilterName = "NETBIRD-ACL"
|
||||
// ChainInputFilterName is the name of the chain that is used for filtering incoming packets
|
||||
ChainInputFilterName = "NETBIRD-ACL-INPUT"
|
||||
|
||||
// ChainOutputFilterName is the name of the chain that is used for filtering outgoing packets
|
||||
ChainOutputFilterName = "NETBIRD-ACL-OUTPUT"
|
||||
)
|
||||
|
||||
// jumpNetbirdInputDefaultRule always added by manager to the input chain for all trafic from the Netbird interface
|
||||
var jumpNetbirdInputDefaultRule = []string{"-j", ChainInputFilterName}
|
||||
|
||||
// jumpNetbirdOutputDefaultRule always added by manager to the output chain for all trafic from the Netbird interface
|
||||
var jumpNetbirdOutputDefaultRule = []string{"-j", ChainOutputFilterName}
|
||||
|
||||
// dropAllDefaultRule in the Netbird chain
|
||||
var dropAllDefaultRule = []string{"-j", "DROP"}
|
||||
|
||||
// Manager of iptables firewall
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
ipv4Client *iptables.IPTables
|
||||
ipv6Client *iptables.IPTables
|
||||
|
||||
wgIfaceName string
|
||||
}
|
||||
|
||||
// Create iptables firewall manager
|
||||
func Create() (*Manager, error) {
|
||||
m := &Manager{}
|
||||
func Create(wgIfaceName string) (*Manager, error) {
|
||||
m := &Manager{
|
||||
wgIfaceName: wgIfaceName,
|
||||
}
|
||||
|
||||
// init clients for booth ipv4 and ipv6
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
@@ -38,118 +55,266 @@ func Create() (*Manager, error) {
|
||||
|
||||
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ip6tables is not installed in the system or not supported")
|
||||
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
|
||||
} else {
|
||||
m.ipv6Client = ipv6Client
|
||||
}
|
||||
m.ipv6Client = ipv6Client
|
||||
|
||||
if err := m.Reset(); err != nil {
|
||||
return nil, fmt.Errorf("failed to reset firewall: %s", err)
|
||||
return nil, fmt.Errorf("failed to reset firewall: %v", err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment is empty rule ID is used as comment
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
port *fw.Port,
|
||||
direction fw.Direction,
|
||||
protocol fw.Protocol,
|
||||
sPort *fw.Port,
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
client := m.client(ip)
|
||||
ok, err := client.ChainExists("filter", ChainFilterName)
|
||||
|
||||
client, err := m.client(ip)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if chain exists: %s", err)
|
||||
}
|
||||
if !ok {
|
||||
if err := client.NewChain("filter", ChainFilterName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create chain: %s", err)
|
||||
}
|
||||
}
|
||||
if port == nil || port.Values == nil || (port.IsRange && len(port.Values) != 2) {
|
||||
return nil, fmt.Errorf("invalid port definition")
|
||||
}
|
||||
pv := strconv.Itoa(port.Values[0])
|
||||
if port.IsRange {
|
||||
pv += ":" + strconv.Itoa(port.Values[1])
|
||||
}
|
||||
specs := m.filterRuleSpecs("filter", ChainFilterName, ip, pv, direction, action, comment)
|
||||
if err := client.AppendUnique("filter", ChainFilterName, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rule := &Rule{
|
||||
id: uuid.New().String(),
|
||||
specs: specs,
|
||||
v6: ip.To4() == nil,
|
||||
|
||||
var dPortVal, sPortVal string
|
||||
if dPort != nil && dPort.Values != nil {
|
||||
// TODO: we support only one port per rule in current implementation of ACLs
|
||||
dPortVal = strconv.Itoa(dPort.Values[0])
|
||||
}
|
||||
return rule, nil
|
||||
if sPort != nil && sPort.Values != nil {
|
||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
||||
}
|
||||
|
||||
ruleID := uuid.New().String()
|
||||
if comment == "" {
|
||||
comment = ruleID
|
||||
}
|
||||
|
||||
specs := m.filterRuleSpecs(
|
||||
"filter",
|
||||
ip,
|
||||
string(protocol),
|
||||
sPortVal,
|
||||
dPortVal,
|
||||
direction,
|
||||
action,
|
||||
comment,
|
||||
)
|
||||
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check is output rule already exists: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("input rule already exists")
|
||||
}
|
||||
|
||||
if err := client.Insert("filter", ChainOutputFilterName, 1, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
ok, err := client.Exists("filter", ChainInputFilterName, specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check is input rule already exists: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("input rule already exists")
|
||||
}
|
||||
|
||||
if err := client.Insert("filter", ChainInputFilterName, 1, specs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &Rule{
|
||||
id: ruleID,
|
||||
specs: specs,
|
||||
dst: direction == fw.RuleDirectionOUT,
|
||||
v6: ip.To4() == nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
client := m.ipv4Client
|
||||
if r.v6 {
|
||||
if m.ipv6Client == nil {
|
||||
return fmt.Errorf("ipv6 is not supported")
|
||||
}
|
||||
client = m.ipv6Client
|
||||
}
|
||||
return client.Delete("filter", ChainFilterName, r.specs...)
|
||||
|
||||
if r.dst {
|
||||
return client.Delete("filter", ChainOutputFilterName, r.specs...)
|
||||
}
|
||||
return client.Delete("filter", ChainInputFilterName, r.specs...)
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
if err := m.reset(m.ipv4Client, "filter", ChainFilterName); err != nil {
|
||||
return fmt.Errorf("clean ipv4 firewall ACL chain: %w", err)
|
||||
|
||||
if err := m.reset(m.ipv4Client, "filter"); err != nil {
|
||||
return fmt.Errorf("clean ipv4 firewall ACL input chain: %w", err)
|
||||
}
|
||||
if err := m.reset(m.ipv6Client, "filter", ChainFilterName); err != nil {
|
||||
return fmt.Errorf("clean ipv6 firewall ACL chain: %w", err)
|
||||
if m.ipv6Client != nil {
|
||||
if err := m.reset(m.ipv6Client, "filter"); err != nil {
|
||||
return fmt.Errorf("clean ipv6 firewall ACL input chain: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reset firewall chain, clear it and drop it
|
||||
func (m *Manager) reset(client *iptables.IPTables, table, chain string) error {
|
||||
ok, err := client.ChainExists(table, chain)
|
||||
func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
||||
ok, err := client.ChainExists(table, ChainInputFilterName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if chain exists: %w", err)
|
||||
return fmt.Errorf("failed to check if input chain exists: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
if ok {
|
||||
specs := append([]string{"-i", m.wgIfaceName}, jumpNetbirdInputDefaultRule...)
|
||||
if ok, err := client.Exists("filter", "INPUT", specs...); err != nil {
|
||||
return err
|
||||
} else if ok {
|
||||
if err := client.Delete("filter", "INPUT", specs...); err != nil {
|
||||
log.WithError(err).Errorf("failed to delete default input rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = client.ChainExists(table, ChainOutputFilterName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if output chain exists: %w", err)
|
||||
}
|
||||
if ok {
|
||||
specs := append([]string{"-o", m.wgIfaceName}, jumpNetbirdOutputDefaultRule...)
|
||||
if ok, err := client.Exists("filter", "OUTPUT", specs...); err != nil {
|
||||
return err
|
||||
} else if ok {
|
||||
if err := client.Delete("filter", "OUTPUT", specs...); err != nil {
|
||||
log.WithError(err).Errorf("failed to delete default output rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := client.ClearAndDeleteChain(table, ChainInputFilterName); err != nil {
|
||||
log.Errorf("failed to clear and delete input chain: %v", err)
|
||||
return nil
|
||||
}
|
||||
if err := client.ClearChain(table, ChainFilterName); err != nil {
|
||||
return fmt.Errorf("failed to clear chain: %w", err)
|
||||
|
||||
if err := client.ClearAndDeleteChain(table, ChainOutputFilterName); err != nil {
|
||||
log.Errorf("failed to clear and delete input chain: %v", err)
|
||||
return nil
|
||||
}
|
||||
return client.DeleteChain(table, ChainFilterName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func (m *Manager) filterRuleSpecs(
|
||||
table string, chain string, ip net.IP, port string,
|
||||
direction fw.Direction, action fw.Action, comment string,
|
||||
table string, ip net.IP, protocol string, sPort, dPort string,
|
||||
direction fw.RuleDirection, action fw.Action, comment string,
|
||||
) (specs []string) {
|
||||
if direction == fw.DirectionSrc {
|
||||
switch direction {
|
||||
case fw.RuleDirectionIN:
|
||||
specs = append(specs, "-s", ip.String())
|
||||
case fw.RuleDirectionOUT:
|
||||
specs = append(specs, "-d", ip.String())
|
||||
}
|
||||
if protocol != "all" {
|
||||
specs = append(specs, "-p", protocol)
|
||||
}
|
||||
if sPort != "" {
|
||||
specs = append(specs, "--sport", sPort)
|
||||
}
|
||||
if dPort != "" {
|
||||
specs = append(specs, "--dport", dPort)
|
||||
}
|
||||
specs = append(specs, "-p", "tcp", "--dport", port)
|
||||
specs = append(specs, "-j", m.actionToStr(action))
|
||||
return append(specs, "-m", "comment", "--comment", comment)
|
||||
}
|
||||
|
||||
// client returns corresponding iptables client for the given ip
|
||||
func (m *Manager) client(ip net.IP) *iptables.IPTables {
|
||||
// rawClient returns corresponding iptables client for the given ip
|
||||
func (m *Manager) rawClient(ip net.IP) (*iptables.IPTables, error) {
|
||||
if ip.To4() != nil {
|
||||
return m.ipv4Client
|
||||
return m.ipv4Client, nil
|
||||
}
|
||||
return m.ipv6Client
|
||||
if m.ipv6Client == nil {
|
||||
return nil, fmt.Errorf("ipv6 is not supported")
|
||||
}
|
||||
return m.ipv6Client, nil
|
||||
}
|
||||
|
||||
// client returns client with initialized chain and default rules
|
||||
func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
|
||||
client, err := m.rawClient(ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ok, err := client.ChainExists("filter", ChainInputFilterName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
if err := client.NewChain("filter", ChainInputFilterName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create input chain: %w", err)
|
||||
}
|
||||
|
||||
if err := client.AppendUnique("filter", ChainInputFilterName, dropAllDefaultRule...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
|
||||
}
|
||||
|
||||
specs := append([]string{"-i", m.wgIfaceName}, jumpNetbirdInputDefaultRule...)
|
||||
if err := client.AppendUnique("filter", "INPUT", specs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
ok, err = client.ChainExists("filter", ChainOutputFilterName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
|
||||
}
|
||||
|
||||
if !ok {
|
||||
if err := client.NewChain("filter", ChainOutputFilterName); err != nil {
|
||||
return nil, fmt.Errorf("failed to create output chain: %w", err)
|
||||
}
|
||||
|
||||
if err := client.AppendUnique("filter", ChainOutputFilterName, dropAllDefaultRule...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create default drop all in netbird output chain: %w", err)
|
||||
}
|
||||
|
||||
specs := append([]string{"-o", m.wgIfaceName}, jumpNetbirdOutputDefaultRule...)
|
||||
if err := client.AppendUnique("filter", "OUTPUT", specs...); err != nil {
|
||||
return nil, fmt.Errorf("failed to create output chain jump rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (m *Manager) actionToStr(action fw.Action) string {
|
||||
|
||||
@@ -1,105 +1,129 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
)
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
func TestIptablesManager(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
manager, err := Create()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// just check on the local interface
|
||||
manager, err := Create("lo")
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule1 fw.Rule
|
||||
t.Run("add first rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Proto: fw.PortProtocolTCP, Values: []int{8080}}
|
||||
rule1, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTP traffic")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add rule: %v", err)
|
||||
}
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, true, rule1.(*Rule).specs...)
|
||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||
})
|
||||
|
||||
var rule2 fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Proto: fw.PortProtocolTCP,
|
||||
Values: []int{8043: 8046},
|
||||
}
|
||||
rule2, err = manager.AddFiltering(
|
||||
ip, port, fw.DirectionDst, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add rule: %v", err)
|
||||
}
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, true, rule2.(*Rule).specs...)
|
||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
|
||||
})
|
||||
|
||||
t.Run("delete first rule", func(t *testing.T) {
|
||||
if err := manager.DeleteRule(rule1); err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, false, rule1.(*Rule).specs...)
|
||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
if err := manager.DeleteRule(rule2); err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, false, rule2.(*Rule).specs...)
|
||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, false, rule2.(*Rule).specs...)
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Proto: fw.PortProtocolUDP, Values: []int{5353}}
|
||||
_, err = manager.AddFiltering(ip, port, fw.DirectionDst, fw.ActionAccept, "accept Fake DNS traffic")
|
||||
if err != nil {
|
||||
t.Errorf("failed to add rule: %v", err)
|
||||
}
|
||||
port := &fw.Port{Values: []int{5353}}
|
||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept Fake DNS traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if err := manager.Reset(); err != nil {
|
||||
t.Errorf("failed to reset: %v", err)
|
||||
}
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
|
||||
ok, err := ipv4Client.ChainExists("filter", ChainFilterName)
|
||||
if err != nil {
|
||||
t.Errorf("failed to drop chain: %v", err)
|
||||
}
|
||||
ok, err := ipv4Client.ChainExists("filter", ChainInputFilterName)
|
||||
require.NoError(t, err, "failed check chain exists")
|
||||
|
||||
if ok {
|
||||
t.Errorf("chain '%v' still exists after Reset", ChainFilterName)
|
||||
require.NoErrorf(t, err, "chain '%v' still exists after Reset", ChainInputFilterName)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, mustExists bool, rulespec ...string) {
|
||||
exists, err := ipv4Client.Exists("filter", ChainFilterName, rulespec...)
|
||||
if err != nil {
|
||||
t.Errorf("failed to check rule: %v", err)
|
||||
return
|
||||
}
|
||||
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
|
||||
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
|
||||
require.NoError(t, err, "failed to check rule")
|
||||
require.Falsef(t, !exists && mustExists, "rule '%v' does not exist", rulespec)
|
||||
require.Falsef(t, exists && !mustExists, "rule '%v' exist", rulespec)
|
||||
}
|
||||
|
||||
if !exists && mustExists {
|
||||
t.Errorf("rule '%v' does not exist", rulespec)
|
||||
return
|
||||
}
|
||||
if exists && !mustExists {
|
||||
t.Errorf("rule '%v' exist", rulespec)
|
||||
return
|
||||
func TestIptablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create("lo")
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
_, err = manager.client(net.ParseIP("10.20.0.100"))
|
||||
require.NoError(t, err)
|
||||
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
||||
}
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ package iptables
|
||||
type Rule struct {
|
||||
id string
|
||||
specs []string
|
||||
dst bool
|
||||
v6 bool
|
||||
}
|
||||
|
||||
|
||||
435
client/firewall/nftables/manager_linux.go
Normal file
435
client/firewall/nftables/manager_linux.go
Normal file
@@ -0,0 +1,435 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
)
|
||||
|
||||
const (
|
||||
// FilterTableName is the name of the table that is used for filtering by the Netbird client
|
||||
FilterTableName = "netbird-acl"
|
||||
|
||||
// FilterInputChainName is the name of the chain that is used for filtering incoming packets
|
||||
FilterInputChainName = "netbird-acl-input-filter"
|
||||
|
||||
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
|
||||
FilterOutputChainName = "netbird-acl-output-filter"
|
||||
)
|
||||
|
||||
// Manager of iptables firewall
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
conn *nftables.Conn
|
||||
tableIPv4 *nftables.Table
|
||||
tableIPv6 *nftables.Table
|
||||
|
||||
filterInputChainIPv4 *nftables.Chain
|
||||
filterOutputChainIPv4 *nftables.Chain
|
||||
|
||||
filterInputChainIPv6 *nftables.Chain
|
||||
filterOutputChainIPv6 *nftables.Chain
|
||||
|
||||
wgIfaceName string
|
||||
}
|
||||
|
||||
// Create nftables firewall manager
|
||||
func Create(wgIfaceName string) (*Manager, error) {
|
||||
m := &Manager{
|
||||
conn: &nftables.Conn{},
|
||||
wgIfaceName: wgIfaceName,
|
||||
}
|
||||
|
||||
if err := m.Reset(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
proto fw.Protocol,
|
||||
sPort *fw.Port,
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var (
|
||||
err error
|
||||
table *nftables.Table
|
||||
chain *nftables.Chain
|
||||
)
|
||||
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
table, chain, err = m.chain(
|
||||
ip,
|
||||
FilterOutputChainName,
|
||||
nftables.ChainHookOutput,
|
||||
nftables.ChainPriorityFilter,
|
||||
nftables.ChainTypeFilter)
|
||||
} else {
|
||||
table, chain, err = m.chain(
|
||||
ip,
|
||||
FilterInputChainName,
|
||||
nftables.ChainHookInput,
|
||||
nftables.ChainPriorityFilter,
|
||||
nftables.ChainTypeFilter)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ifaceKey := expr.MetaKeyIIFNAME
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
ifaceKey = expr.MetaKeyOIFNAME
|
||||
}
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIfaceName),
|
||||
},
|
||||
}
|
||||
|
||||
if proto != "all" {
|
||||
expressions = append(expressions, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: uint32(9),
|
||||
Len: uint32(1),
|
||||
})
|
||||
|
||||
var protoData []byte
|
||||
switch proto {
|
||||
case fw.ProtocolTCP:
|
||||
protoData = []byte{unix.IPPROTO_TCP}
|
||||
case fw.ProtocolUDP:
|
||||
protoData = []byte{unix.IPPROTO_UDP}
|
||||
case fw.ProtocolICMP:
|
||||
protoData = []byte{unix.IPPROTO_ICMP}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
expressions = append(expressions, &expr.Cmp{
|
||||
Register: 1,
|
||||
Op: expr.CmpOpEq,
|
||||
Data: protoData,
|
||||
})
|
||||
}
|
||||
|
||||
// source address position
|
||||
var adrLen, adrOffset uint32
|
||||
if ip.To4() == nil {
|
||||
adrLen = 16
|
||||
adrOffset = 8
|
||||
} else {
|
||||
adrLen = 4
|
||||
adrOffset = 12
|
||||
}
|
||||
|
||||
// change to destination address position if need
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
adrOffset += adrLen
|
||||
}
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: adrOffset,
|
||||
Len: adrLen,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
},
|
||||
)
|
||||
|
||||
if sPort != nil && len(sPort.Values) != 0 {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 0,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: encodePort(*sPort),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
if dPort != nil && len(dPort.Values) != 0 {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: encodePort(*dPort),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
if action == fw.ActionAccept {
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
||||
} else {
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
userData := []byte(strings.Join([]string{id, comment}, " "))
|
||||
|
||||
_ = m.conn.InsertRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Position: 0,
|
||||
Exprs: expressions,
|
||||
UserData: userData,
|
||||
})
|
||||
|
||||
if err := m.conn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list, err := m.conn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add the rule to the chain
|
||||
rule := &Rule{id: id}
|
||||
for _, r := range list {
|
||||
if bytes.Equal(r.UserData, userData) {
|
||||
rule.Rule = r
|
||||
break
|
||||
}
|
||||
}
|
||||
if rule.Rule == nil {
|
||||
return nil, fmt.Errorf("rule not found")
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// chain returns the chain for the given IP address with specific settings
|
||||
func (m *Manager) chain(
|
||||
ip net.IP,
|
||||
name string,
|
||||
hook nftables.ChainHook,
|
||||
priority nftables.ChainPriority,
|
||||
cType nftables.ChainType,
|
||||
) (*nftables.Table, *nftables.Chain, error) {
|
||||
var err error
|
||||
|
||||
getChain := func(c *nftables.Chain, tf nftables.TableFamily) (*nftables.Chain, error) {
|
||||
if c != nil {
|
||||
return c, nil
|
||||
}
|
||||
return m.createChainIfNotExists(tf, name, hook, priority, cType)
|
||||
}
|
||||
|
||||
if ip.To4() != nil {
|
||||
if name == FilterInputChainName {
|
||||
m.filterInputChainIPv4, err = getChain(m.filterInputChainIPv4, nftables.TableFamilyIPv4)
|
||||
return m.tableIPv4, m.filterInputChainIPv4, err
|
||||
}
|
||||
m.filterOutputChainIPv4, err = getChain(m.filterOutputChainIPv4, nftables.TableFamilyIPv4)
|
||||
return m.tableIPv4, m.filterOutputChainIPv4, err
|
||||
}
|
||||
if name == FilterInputChainName {
|
||||
m.filterInputChainIPv6, err = getChain(m.filterInputChainIPv6, nftables.TableFamilyIPv6)
|
||||
return m.tableIPv4, m.filterInputChainIPv6, err
|
||||
}
|
||||
m.filterOutputChainIPv6, err = getChain(m.filterOutputChainIPv6, nftables.TableFamilyIPv6)
|
||||
return m.tableIPv4, m.filterOutputChainIPv6, err
|
||||
}
|
||||
|
||||
// table returns the table for the given family of the IP address
|
||||
func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
if family == nftables.TableFamilyIPv4 {
|
||||
if m.tableIPv4 != nil {
|
||||
return m.tableIPv4, nil
|
||||
}
|
||||
|
||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.tableIPv4 = table
|
||||
return m.tableIPv4, nil
|
||||
}
|
||||
|
||||
if m.tableIPv6 != nil {
|
||||
return m.tableIPv6, nil
|
||||
}
|
||||
|
||||
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.tableIPv6 = table
|
||||
return m.tableIPv6, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
tables, err := m.conn.ListTablesOfFamily(family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == FilterTableName {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
return m.conn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}), nil
|
||||
}
|
||||
|
||||
func (m *Manager) createChainIfNotExists(
|
||||
family nftables.TableFamily,
|
||||
name string,
|
||||
hooknum nftables.ChainHook,
|
||||
priority nftables.ChainPriority,
|
||||
chainType nftables.ChainType,
|
||||
) (*nftables.Chain, error) {
|
||||
table, err := m.table(family)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chains, err := m.conn.ListChainsOfTableFamily(family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
|
||||
for _, c := range chains {
|
||||
if c.Name == name && c.Table.Name == table.Name {
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
polAccept := nftables.ChainPolicyAccept
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
Table: table,
|
||||
Hooknum: hooknum,
|
||||
Priority: priority,
|
||||
Type: chainType,
|
||||
Policy: &polAccept,
|
||||
}
|
||||
|
||||
chain = m.conn.AddChain(chain)
|
||||
|
||||
ifaceKey := expr.MetaKeyIIFNAME
|
||||
if name == FilterOutputChainName {
|
||||
ifaceKey = expr.MetaKeyOIFNAME
|
||||
}
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIfaceName),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
_ = m.conn.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
|
||||
if err := m.conn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return chain, nil
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
nativeRule, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
if err := m.conn.DelRule(nativeRule.Rule); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.conn.Flush()
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
chains, err := m.conn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
for _, c := range chains {
|
||||
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
||||
m.conn.DelChain(c)
|
||||
}
|
||||
}
|
||||
|
||||
tables, err := m.conn.ListTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
for _, t := range tables {
|
||||
if t.Name == FilterTableName {
|
||||
m.conn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
return m.conn.Flush()
|
||||
}
|
||||
|
||||
func encodePort(port fw.Port) []byte {
|
||||
bs := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
|
||||
return bs
|
||||
}
|
||||
|
||||
func ifname(n string) []byte {
|
||||
b := make([]byte, 16)
|
||||
copy(b, []byte(n+"\x00"))
|
||||
return b
|
||||
}
|
||||
137
client/firewall/nftables/manager_linux_test.go
Normal file
137
client/firewall/nftables/manager_linux_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
)
|
||||
|
||||
func TestNftablesManager(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create("lo")
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{53}},
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
// 1 regular rule and other "drop all rule" for the interface
|
||||
require.Len(t, rules, 2, "expected 1 rule")
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
expectedExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: uint32(9),
|
||||
Len: uint32(1),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Op: expr.CmpOpEq,
|
||||
Data: []byte{unix.IPPROTO_TCP},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 53},
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
|
||||
|
||||
err = manager.DeleteRule(rule)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
require.Len(t, rules, 1, "expected 1 rules after deleteion")
|
||||
|
||||
err = manager.Reset()
|
||||
require.NoError(t, err, "failed to reset")
|
||||
}
|
||||
|
||||
func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create("lo")
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
||||
}
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
||||
})
|
||||
}
|
||||
}
|
||||
16
client/firewall/nftables/rule_linux.go
Normal file
16
client/firewall/nftables/rule_linux.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"github.com/google/nftables"
|
||||
)
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
*nftables.Rule
|
||||
id string
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) GetRuleID() string {
|
||||
return r.id
|
||||
}
|
||||
@@ -1,14 +1,23 @@
|
||||
package firewall
|
||||
|
||||
// PortProtocol is the protocol of the port
|
||||
type PortProtocol string
|
||||
// Protocol is the protocol of the port
|
||||
type Protocol string
|
||||
|
||||
const (
|
||||
// PortProtocolTCP is the TCP protocol
|
||||
PortProtocolTCP PortProtocol = "tcp"
|
||||
// ProtocolTCP is the TCP protocol
|
||||
ProtocolTCP Protocol = "tcp"
|
||||
|
||||
// PortProtocolUDP is the UDP protocol
|
||||
PortProtocolUDP PortProtocol = "udp"
|
||||
// ProtocolUDP is the UDP protocol
|
||||
ProtocolUDP Protocol = "udp"
|
||||
|
||||
// ProtocolICMP is the ICMP protocol
|
||||
ProtocolICMP Protocol = "icmp"
|
||||
|
||||
// ProtocolALL cover all supported protocols
|
||||
ProtocolALL Protocol = "all"
|
||||
|
||||
// ProtocolUnknown unknown protocol
|
||||
ProtocolUnknown Protocol = "unknown"
|
||||
)
|
||||
|
||||
// Port of the address for firewall rule
|
||||
@@ -18,7 +27,4 @@ type Port struct {
|
||||
|
||||
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
||||
Values []int
|
||||
|
||||
// Proto is the protocol of the port
|
||||
Proto PortProtocol
|
||||
}
|
||||
|
||||
27
client/firewall/uspfilter/rule.go
Normal file
27
client/firewall/uspfilter/rule.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
)
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
id string
|
||||
ip net.IP
|
||||
ipLayer gopacket.LayerType
|
||||
protoLayer gopacket.LayerType
|
||||
direction fw.RuleDirection
|
||||
sPort uint16
|
||||
dPort uint16
|
||||
drop bool
|
||||
comment string
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) GetRuleID() string {
|
||||
return r.id
|
||||
}
|
||||
291
client/firewall/uspfilter/uspfilter.go
Normal file
291
client/firewall/uspfilter/uspfilter.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const layerTypeAll = 0
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
SetFiltering(iface.PacketFilter) error
|
||||
}
|
||||
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
outgoingRules []Rule
|
||||
incomingRules []Rule
|
||||
rulesIndex map[string]int
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
type decoder struct {
|
||||
eth layers.Ethernet
|
||||
ip4 layers.IPv4
|
||||
ip6 layers.IPv6
|
||||
tcp layers.TCP
|
||||
udp layers.UDP
|
||||
icmp4 layers.ICMPv4
|
||||
icmp6 layers.ICMPv6
|
||||
decoded []gopacket.LayerType
|
||||
parser *gopacket.DecodingLayerParser
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface IFaceMapper) (*Manager, error) {
|
||||
m := &Manager{
|
||||
rulesIndex: make(map[string]int),
|
||||
decoders: sync.Pool{
|
||||
New: func() any {
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
return d
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := iface.SetFiltering(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddFiltering(
|
||||
ip net.IP,
|
||||
proto fw.Protocol,
|
||||
sPort *fw.Port,
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
r := Rule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
direction: direction,
|
||||
drop: action == fw.ActionDrop,
|
||||
comment: comment,
|
||||
}
|
||||
if ipNormalized := ip.To4(); ipNormalized != nil {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
r.ip = ipNormalized
|
||||
}
|
||||
|
||||
if sPort != nil && len(sPort.Values) == 1 {
|
||||
r.sPort = uint16(sPort.Values[0])
|
||||
}
|
||||
|
||||
if dPort != nil && len(dPort.Values) == 1 {
|
||||
r.dPort = uint16(dPort.Values[0])
|
||||
}
|
||||
|
||||
switch proto {
|
||||
case fw.ProtocolTCP:
|
||||
r.protoLayer = layers.LayerTypeTCP
|
||||
case fw.ProtocolUDP:
|
||||
r.protoLayer = layers.LayerTypeUDP
|
||||
case fw.ProtocolICMP:
|
||||
r.protoLayer = layers.LayerTypeICMPv4
|
||||
if r.ipLayer == layers.LayerTypeIPv6 {
|
||||
r.protoLayer = layers.LayerTypeICMPv6
|
||||
}
|
||||
case fw.ProtocolALL:
|
||||
r.protoLayer = layerTypeAll
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
var p int
|
||||
if direction == fw.RuleDirectionIN {
|
||||
m.incomingRules = append(m.incomingRules, r)
|
||||
p = len(m.incomingRules) - 1
|
||||
} else {
|
||||
m.outgoingRules = append(m.outgoingRules, r)
|
||||
p = len(m.outgoingRules) - 1
|
||||
}
|
||||
m.rulesIndex[r.id] = p
|
||||
m.mutex.Unlock()
|
||||
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
|
||||
p, ok := m.rulesIndex[r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.rulesIndex, r.id)
|
||||
|
||||
var toUpdate []Rule
|
||||
if r.direction == fw.RuleDirectionIN {
|
||||
m.incomingRules = append(m.incomingRules[:p], m.incomingRules[p+1:]...)
|
||||
toUpdate = m.incomingRules
|
||||
} else {
|
||||
m.outgoingRules = append(m.outgoingRules[:p], m.outgoingRules[p+1:]...)
|
||||
toUpdate = m.outgoingRules
|
||||
}
|
||||
|
||||
for i := 0; i < len(toUpdate); i++ {
|
||||
m.rulesIndex[toUpdate[i].id] = i
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = m.outgoingRules[:0]
|
||||
m.incomingRules = m.incomingRules[:0]
|
||||
m.rulesIndex = make(map[string]int)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropOutgoing filter outgoing packets
|
||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||
return m.dropFilter(packetData, m.outgoingRules, false)
|
||||
}
|
||||
|
||||
// DropIncoming filter incoming packets
|
||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||
return m.dropFilter(packetData, m.incomingRules, true)
|
||||
}
|
||||
|
||||
// dropFilter imlements same logic for booth direction of the traffic
|
||||
func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket bool) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
log.Tracef("couldn't decode layer, err: %s", err)
|
||||
return true
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
log.Tracef("not enough levels in network packet")
|
||||
return true
|
||||
}
|
||||
|
||||
ipLayer := d.decoded[0]
|
||||
|
||||
switch ipLayer {
|
||||
case layers.LayerTypeIPv4:
|
||||
if !m.wgNetwork.Contains(d.ip4.SrcIP) || !m.wgNetwork.Contains(d.ip4.DstIP) {
|
||||
return false
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
if !m.wgNetwork.Contains(d.ip6.SrcIP) || !m.wgNetwork.Contains(d.ip6.DstIP) {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
||||
return true
|
||||
}
|
||||
payloadLayer := d.decoded[1]
|
||||
|
||||
// check if IP address match by IP
|
||||
for _, rule := range rules {
|
||||
switch ipLayer {
|
||||
case layers.LayerTypeIPv4:
|
||||
if isIncomingPacket {
|
||||
if !d.ip4.SrcIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if !d.ip4.DstIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
if isIncomingPacket {
|
||||
if !d.ip6.SrcIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if !d.ip6.DstIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rule.protoLayer == layerTypeAll {
|
||||
return rule.drop
|
||||
}
|
||||
|
||||
if payloadLayer != rule.protoLayer {
|
||||
continue
|
||||
}
|
||||
|
||||
switch payloadLayer {
|
||||
case layers.LayerTypeTCP:
|
||||
if rule.sPort == 0 && rule.dPort == 0 {
|
||||
return rule.drop
|
||||
}
|
||||
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
||||
return rule.drop
|
||||
}
|
||||
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
||||
return rule.drop
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
if rule.sPort == 0 && rule.dPort == 0 {
|
||||
return rule.drop
|
||||
}
|
||||
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
||||
return rule.drop
|
||||
}
|
||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
||||
return rule.drop
|
||||
}
|
||||
return rule.drop
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return rule.drop
|
||||
}
|
||||
}
|
||||
|
||||
// default policy is DROP ALL
|
||||
return true
|
||||
}
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
||||
m.wgNetwork = network
|
||||
}
|
||||
207
client/firewall/uspfilter/uspfilter_test.go
Normal file
207
client/firewall/uspfilter/uspfilter_test.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type IFaceMock struct {
|
||||
SetFilteringFunc func(iface.PacketFilter) error
|
||||
}
|
||||
|
||||
func (i *IFaceMock) SetFiltering(iface iface.PacketFilter) error {
|
||||
if i.SetFilteringFunc == nil {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
return i.SetFilteringFunc(iface)
|
||||
}
|
||||
|
||||
func TestManagerCreate(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilteringFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
t.Error("Manager is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerAddFiltering(t *testing.T) {
|
||||
isSetFilteringCalled := false
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilteringFunc: func(iface.PacketFilter) error {
|
||||
isSetFilteringCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
direction := fw.RuleDirectionOUT
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if rule == nil {
|
||||
t.Error("Rule is nil")
|
||||
return
|
||||
}
|
||||
|
||||
if !isSetFilteringCalled {
|
||||
t.Error("SetFiltering was not called")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerDeleteRule(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilteringFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
direction := fw.RuleDirectionOUT
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip = net.ParseIP("192.168.1.1")
|
||||
proto = fw.ProtocolTCP
|
||||
port = &fw.Port{Values: []int{80}}
|
||||
direction = fw.RuleDirectionIN
|
||||
action = fw.ActionDrop
|
||||
comment = "Test rule 2"
|
||||
|
||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = m.DeleteRule(rule)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if idx, ok := m.rulesIndex[rule2.GetRuleID()]; !ok || len(m.incomingRules) != 1 || idx != 0 {
|
||||
t.Errorf("rule2 is not in the rulesIndex")
|
||||
}
|
||||
|
||||
err = m.DeleteRule(rule2)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(m.rulesIndex) != 0 || len(m.incomingRules) != 0 {
|
||||
t.Errorf("rule1 still in the rulesIndex")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerReset(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilteringFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
direction := fw.RuleDirectionOUT
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = m.Reset()
|
||||
if err != nil {
|
||||
t.Errorf("failed to reset Manager: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(m.rulesIndex) != 0 || len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
||||
t.Errorf("rules is not empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilteringFunc: func(iface.PacketFilter) error { return nil },
|
||||
}
|
||||
manager, err := Create(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(); err != nil {
|
||||
t.Errorf("clear the manager state: %v", err)
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
||||
}
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user