Code cleaning in firewall package

This commit is contained in:
Zoltán Papp
2025-01-25 20:29:06 +01:00
parent 8185614362
commit efa8c17d27
42 changed files with 889 additions and 868 deletions

View File

@@ -8,13 +8,13 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/interface"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (_interface.Firewall, error) {
if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}

View File

@@ -11,8 +11,8 @@ import (
"github.com/google/nftables"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/interface"
nbiptables "github.com/netbirdio/netbird/client/firewall/iptables"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -33,7 +33,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (_interface.Firewall, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
@@ -50,7 +50,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
return createUserspaceFirewall(iface, fm)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (_interface.Firewall, error) {
fm, err := createFW(iface)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
@@ -63,7 +63,7 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager)
return fm, nil
}
func createFW(iface IFaceMapper) (firewall.Manager, error) {
func createFW(iface IFaceMapper) (_interface.Firewall, error) {
switch check() {
case IPTABLES:
log.Info("creating an iptables firewall manager")
@@ -77,7 +77,7 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
func createUserspaceFirewall(iface IFaceMapper, fm _interface.Firewall) (_interface.Firewall, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)

View File

@@ -0,0 +1,67 @@
package _interface
import (
"net"
"net/netip"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Firewall is the high level abstraction of a firewall manager
//
// It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality
type Firewall interface {
Init(stateManager *statemanager.Manager) error
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddPeerFiltering adds a rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
AddPeerFiltering(
ip net.IP,
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
ipsetName string,
comment string,
) ([]types.Rule, error)
// DeletePeerRule from the firewall by rule definition
DeletePeerRule(rule types.Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto types.Protocol, sPort *types.Port, dPort *types.Port, action types.Action) (types.Rule, error)
// DeleteRouteRule deletes a routing rule
DeleteRouteRule(rule types.Rule) error
// AddNatRule inserts a routing NAT rule
AddNatRule(pair types.RouterPair) error
// RemoveNatRule removes a routing NAT rule
RemoveNatRule(pair types.RouterPair) error
// SetLegacyManagement sets the legacy management mode
SetLegacyManagement(legacy bool) error
// Reset firewall to the default state
Reset(stateManager *statemanager.Manager) error
// Flush the changes to firewall controller
Flush() error
// AddDNATRule adds a DNAT rule
AddDNATRule(types.ForwardRule) (types.Rule, error)
// DeleteDNATRule deletes a DNAT rule
// todo: do you need a string ID or the complete rule?
DeleteDNATRule(types.Rule) error
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
)
@@ -80,12 +80,12 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
func (m *aclManager) AddPeerFiltering(
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
protocol types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
ipsetName string,
) ([]firewall.Rule, error) {
) ([]types.Rule, error) {
var dPortVal, sPortVal string
if dPort != nil && dPort.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs
@@ -107,7 +107,7 @@ func (m *aclManager) AddPeerFiltering(
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
ipList.addIP(ip.String())
return []firewall.Rule{&Rule{
return []types.Rule{&Rule{
ruleID: uuid.New().String(),
ipsetName: ipsetName,
ip: ip.String(),
@@ -152,11 +152,11 @@ func (m *aclManager) AddPeerFiltering(
m.updateState()
return []firewall.Rule{rule}, nil
return []types.Rule{rule}, nil
}
// DeletePeerRule from the firewall by rule definition
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
func (m *aclManager) DeletePeerRule(rule types.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
@@ -354,7 +354,7 @@ func (m *aclManager) updateState() {
}
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.Action, ipsetName string) (specs []string) {
func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action types.Action, ipsetName string) (specs []string) {
matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" {
@@ -380,8 +380,8 @@ func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.A
return append(specs, "-j", actionToStr(action))
}
func actionToStr(action firewall.Action) string {
if action == firewall.ActionAccept {
func actionToStr(action types.Action) string {
if action == types.ActionAccept {
return "ACCEPT"
}
return "DROP"

View File

@@ -12,7 +12,8 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/legacy"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -97,13 +98,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// Comment will be ignored because some system this feature is not supported
func (m *Manager) AddPeerFiltering(
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
protocol types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
ipsetName string,
_ string,
) ([]firewall.Rule, error) {
) ([]types.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -113,11 +114,11 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
) (types.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -129,14 +130,14 @@ func (m *Manager) AddRouteFiltering(
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
func (m *Manager) DeletePeerRule(rule types.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.DeletePeerRule(rule)
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
func (m *Manager) DeleteRouteRule(rule types.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -147,14 +148,14 @@ func (m *Manager) IsServerRouteSupported() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
func (m *Manager) AddNatRule(pair types.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddNatRule(pair)
}
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
func (m *Manager) RemoveNatRule(pair types.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -162,7 +163,7 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
}
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
return legacy.SetLegacyRouter(m.router, isLegacy)
}
// Reset firewall to the default state
@@ -200,7 +201,7 @@ func (m *Manager) AllowNetbird() error {
"all",
nil,
nil,
firewall.ActionAccept,
types.ActionAccept,
"",
"",
)
@@ -213,12 +214,12 @@ func (m *Manager) AllowNetbird() error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
func (m *Manager) AddDNATRule(rule types.ForwardRule) (types.Rule, error) {
return nil, fmt.Errorf("not implemented")
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
func (m *Manager) DeleteDNATRule(rule types.Rule) error {
return nil
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/coreos/go-iptables/iptables"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/iface"
)
@@ -68,13 +68,13 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second)
}()
var rule2 []fw.Rule
var rule2 []types.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
port := &types.Port{
Values: []int{8043: 8046},
}
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, types.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule")
for _, r := range rule2 {
@@ -95,8 +95,8 @@ func TestIptablesManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) {
// add second rule
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
port := &types.Port{Values: []int{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, types.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule")
err = manager.Reset(nil)
@@ -141,13 +141,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second)
}()
var rule2 []fw.Rule
var rule2 []types.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
port := &types.Port{
Values: []int{443},
}
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, types.ActionAccept, "default", "accept HTTPS traffic from ports range")
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -214,8 +214,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
port := &types.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, types.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
}

View File

@@ -14,7 +14,7 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -42,11 +42,11 @@ const (
type routeFilteringRuleParams struct {
Sources []netip.Prefix
Destination netip.Prefix
Proto firewall.Protocol
SPort *firewall.Port
DPort *firewall.Port
Direction firewall.RuleDirection
Action firewall.Action
Proto types.Protocol
SPort *types.Port
DPort *types.Port
Direction types.RuleDirection
Action types.Action
SetName string
}
@@ -106,11 +106,11 @@ func (r *router) init(stateManager *statemanager.Manager) error {
func (r *router) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
) (types.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
@@ -118,7 +118,7 @@ func (r *router) AddRouteFiltering(
var setName string
if len(sources) > 1 {
setName = firewall.GenerateSetName(sources)
setName = types.GenerateSetName(sources)
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
@@ -146,7 +146,7 @@ func (r *router) AddRouteFiltering(
return ruleKey, nil
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
func (r *router) DeleteRouteRule(rule types.Rule) error {
ruleKey := rule.GetRuleID()
if rule, exists := r.rules[ruleKey]; exists {
@@ -202,7 +202,7 @@ func (r *router) deleteIpSet(setName string) error {
}
// AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
func (r *router) AddNatRule(pair types.RouterPair) error {
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
@@ -218,7 +218,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
if err := r.addNatRule(types.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
@@ -228,12 +228,12 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
}
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
func (r *router) RemoveNatRule(pair types.RouterPair) error {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
if err := r.removeNatRule(types.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
}
@@ -247,8 +247,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
func (r *router) addLegacyRouteRule(pair types.RouterPair) error {
ruleKey := types.GenRuleKey(types.ForwardingFormat, pair)
if err := r.removeLegacyRouteRule(pair); err != nil {
return err
@@ -264,8 +264,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
return nil
}
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
func (r *router) removeLegacyRouteRule(pair types.RouterPair) error {
ruleKey := types.GenRuleKey(types.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
@@ -293,7 +293,7 @@ func (r *router) SetLegacyManagement(isLegacy bool) {
func (r *router) RemoveAllLegacyRouteRules() error {
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
if !strings.HasPrefix(k, types.ForwardingFormatPrefix) {
continue
}
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
@@ -478,8 +478,8 @@ func (r *router) cleanJumpRules() error {
return nil
}
func (r *router) addNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
func (r *router) addNatRule(pair types.RouterPair) error {
ruleKey := types.GenRuleKey(types.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
@@ -514,8 +514,8 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
return nil
}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
func (r *router) removeNatRule(pair types.RouterPair) error {
ruleKey := types.GenRuleKey(types.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
@@ -567,7 +567,7 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
rule = append(rule, "-d", params.Destination.String())
if params.Proto != firewall.ProtocolALL {
if params.Proto != types.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
rule = append(rule, applyPort("--sport", params.SPort)...)
rule = append(rule, applyPort("--dport", params.DPort)...)
@@ -578,7 +578,7 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
return rule
}
func applyPort(flag string, port *firewall.Port) []string {
func applyPort(flag string, port *types.Port) []string {
if port == nil {
return nil
}

View File

@@ -12,8 +12,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/firewall/types"
nbnet "github.com/netbirdio/netbird/util/net"
)
@@ -54,7 +54,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
require.True(t, exists, "prerouting jump rule should exist")
pair := firewall.RouterPair{
pair := types.RouterPair{
ID: "abc",
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.100.0/24"),
@@ -89,7 +89,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "marking rule should be inserted")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRuleKey := types.GenRuleKey(types.NatFormat, testCase.InputPair)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -114,8 +114,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
}
// Check inverse rule
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inversePair := types.GetInversePair(testCase.InputPair)
inverseRuleKey := types.GenRuleKey(types.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -164,7 +164,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRuleKey := types.GenRuleKey(types.NatFormat, testCase.InputPair)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -183,8 +183,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
require.False(t, found, "marking rule should not exist in the manager map")
// Check inverse rule removal
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inversePair := types.GetInversePair(testCase.InputPair)
inverseRuleKey := types.GenRuleKey(types.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -226,22 +226,22 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
name string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
action firewall.Action
proto types.Protocol
sPort *types.Port
dPort *types.Port
direction types.RuleDirection
action types.Action
expectSet bool
}{
{
name: "Basic TCP rule with single source",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
proto: types.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
dPort: &types.Port{Values: []int{80}},
direction: types.RuleDirectionIN,
action: types.ActionAccept,
expectSet: false,
},
{
@@ -251,77 +251,77 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
netip.MustParsePrefix("192.168.0.0/16"),
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
proto: types.ProtocolUDP,
sPort: &types.Port{Values: []int{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
direction: types.RuleDirectionOUT,
action: types.ActionDrop,
expectSet: true,
},
{
name: "All protocols rule",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
destination: netip.MustParsePrefix("0.0.0.0/0"),
proto: firewall.ProtocolALL,
proto: types.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
direction: types.RuleDirectionIN,
action: types.ActionAccept,
expectSet: false,
},
{
name: "ICMP rule",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolICMP,
proto: types.ProtocolICMP,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
direction: types.RuleDirectionIN,
action: types.ActionAccept,
expectSet: false,
},
{
name: "TCP rule with multiple source ports",
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
proto: types.ProtocolTCP,
sPort: &types.Port{Values: []int{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
direction: types.RuleDirectionOUT,
action: types.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with single IP and port range",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
proto: types.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
dPort: &types.Port{Values: []int{5000, 5100}, IsRange: true},
direction: types.RuleDirectionIN,
action: types.ActionDrop,
expectSet: false,
},
{
name: "TCP rule with source and destination ports",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
proto: types.ProtocolTCP,
sPort: &types.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &types.Port{Values: []int{22}},
direction: types.RuleDirectionOUT,
action: types.ActionAccept,
expectSet: false,
},
{
name: "Drop all incoming traffic",
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
destination: netip.MustParsePrefix("192.168.0.0/24"),
proto: firewall.ProtocolALL,
proto: types.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
direction: types.RuleDirectionIN,
action: types.ActionDrop,
expectSet: false,
},
}
@@ -357,7 +357,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
expectedRule := genRouteFilteringRuleSpec(params)
if tt.expectSet {
setName := firewall.GenerateSetName(tt.sources)
setName := types.GenerateSetName(tt.sources)
params.SetName = setName
expectedRule = genRouteFilteringRuleSpec(params)

View File

@@ -0,0 +1,35 @@
package legacy
import (
"fmt"
"github.com/sirupsen/logrus"
)
// Router defines the interface for legacy management operations
type Router interface {
RemoveAllLegacyRouteRules() error
GetLegacyManagement() bool
SetLegacyManagement(bool)
}
// SetLegacyRouter sets the route manager to use legacy management
func SetLegacyRouter(router Router, isLegacy bool) error {
oldLegacy := router.GetLegacyManagement()
if oldLegacy != isLegacy {
router.SetLegacyManagement(isLegacy)
logrus.Debugf("Set legacy management to %v", isLegacy)
}
// client reconnected to a newer mgmt, we need to clean up the legacy rules
if !isLegacy && oldLegacy {
if err := router.RemoveAllLegacyRouteRules(); err != nil {
return fmt.Errorf("remove legacy routing rules: %v", err)
}
logrus.Debugf("Legacy routing rules removed")
}
return nil
}

View File

@@ -1,196 +0,0 @@
package manager
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"net/netip"
"sort"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t"
PreroutingFormat = "netbird-prerouting-%s-%t"
NatFormat = "netbird-nat-%s-%t"
)
// Rule abstraction should be implemented by each firewall manager
//
// Each firewall type for different OS can use different type
// of the properties to hold data of the created rule
type Rule interface {
// GetRuleID returns the rule id
GetRuleID() string
}
// RuleDirection is the traffic direction which a rule is applied
type RuleDirection int
const (
// RuleDirectionIN applies to filters that handlers incoming traffic
RuleDirectionIN RuleDirection = iota
// RuleDirectionOUT applies to filters that handlers outgoing traffic
RuleDirectionOUT
)
// Action is the action to be taken on a rule
type Action int
const (
// ActionAccept is the action to accept a packet
ActionAccept Action = iota
// ActionDrop is the action to drop a packet
ActionDrop
)
// Manager is the high level abstraction of a firewall manager
//
// It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality
type Manager interface {
Init(stateManager *statemanager.Manager) error
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddPeerFiltering adds a rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
AddPeerFiltering(
ip net.IP,
proto Protocol,
sPort *Port,
dPort *Port,
action Action,
ipsetName string,
comment string,
) ([]Rule, error)
// DeletePeerRule from the firewall by rule definition
DeletePeerRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
// DeleteRouteRule deletes a routing rule
DeleteRouteRule(rule Rule) error
// AddNatRule inserts a routing NAT rule
AddNatRule(pair RouterPair) error
// RemoveNatRule removes a routing NAT rule
RemoveNatRule(pair RouterPair) error
// SetLegacyManagement sets the legacy management mode
SetLegacyManagement(legacy bool) error
// Reset firewall to the default state
Reset(stateManager *statemanager.Manager) error
// Flush the changes to firewall controller
Flush() error
// AddDNATRule adds a DNAT rule
AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes a DNAT rule
// todo: do you need a string ID or the complete rule?
DeleteDNATRule(Rule) error
}
func GenKey(format string, pair RouterPair) string {
return fmt.Sprintf(format, pair.ID, pair.Inverse)
}
// LegacyManager defines the interface for legacy management operations
type LegacyManager interface {
RemoveAllLegacyRouteRules() error
GetLegacyManagement() bool
SetLegacyManagement(bool)
}
// SetLegacyManagement sets the route manager to use legacy management
func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
oldLegacy := router.GetLegacyManagement()
if oldLegacy != isLegacy {
router.SetLegacyManagement(isLegacy)
log.Debugf("Set legacy management to %v", isLegacy)
}
// client reconnected to a newer mgmt, we need to clean up the legacy rules
if !isLegacy && oldLegacy {
if err := router.RemoveAllLegacyRouteRules(); err != nil {
return fmt.Errorf("remove legacy routing rules: %v", err)
}
log.Debugf("Legacy routing rules removed")
}
return nil
}
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
sourcesStr.WriteString(src.String())
}
hash := sha256.Sum256([]byte(sourcesStr.String()))
shortHash := hex.EncodeToString(hash[:])[:8]
return fmt.Sprintf("nb-%s", shortHash)
}
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 {
return prefixes
}
merged := []netip.Prefix{prefixes[0]}
for _, prefix := range prefixes[1:] {
last := merged[len(merged)-1]
if last.Contains(prefix.Addr()) {
// If the current prefix is contained within the last merged prefix, skip it
continue
}
if prefix.Contains(last.Addr()) {
// If the current prefix contains the last merged prefix, replace it
merged[len(merged)-1] = prefix
} else {
// Otherwise, add the current prefix to the merged list
merged = append(merged, prefix)
}
}
return merged
}
// SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func SortPrefixes(prefixes []netip.Prefix) {
sort.Slice(prefixes, func(i, j int) bool {
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
if addrCmp != 0 {
return addrCmp < 0
}
// If IP addresses are the same, compare prefix lengths (longer prefixes first)
return prefixes[i].Bits() > prefixes[j].Bits()
})
}

View File

@@ -1,192 +0,0 @@
package manager_test
import (
"net/netip"
"reflect"
"regexp"
"testing"
"github.com/netbirdio/netbird/client/firewall/manager"
)
func TestGenerateSetName(t *testing.T) {
t.Run("Different orders result in same hash", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
}
})
t.Run("Result format is correct", func(t *testing.T) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
result := manager.GenerateSetName(prefixes)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
if err != nil {
t.Fatalf("Error matching regex: %v", err)
}
if !matched {
t.Errorf("Result format is incorrect: %s", result)
}
})
t.Run("Empty input produces consistent result", func(t *testing.T) {
result1 := manager.GenerateSetName([]netip.Prefix{})
result2 := manager.GenerateSetName([]netip.Prefix{})
if result1 != result2 {
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
}
})
t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("2001:db8::/32"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
}
})
}
func TestMergeIPRanges(t *testing.T) {
tests := []struct {
name string
input []netip.Prefix
expected []netip.Prefix
}{
{
name: "Empty input",
input: []netip.Prefix{},
expected: []netip.Prefix{},
},
{
name: "Single range",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Two non-overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
},
{
name: "One range containing another",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "One range containing another (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.0.0/16"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Overlapping ranges (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.128/25"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Multiple overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Partially overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.2.0/25"),
},
},
{
name: "IPv6 ranges",
input: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("2001:db8:1::/48"),
netip.MustParsePrefix("2001:db8:2::/48"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := manager.MergeIPRanges(tt.input)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
}
})
}
}

View File

@@ -15,7 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
nbnet "github.com/netbirdio/netbird/util/net"
)
@@ -84,13 +84,13 @@ func (m *AclManager) init(workTable *nftables.Table) error {
// rule ID as comment for the rule
func (m *AclManager) AddPeerFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
) ([]types.Rule, error) {
var ipset *nftables.Set
if ipsetName != "" {
var err error
@@ -100,7 +100,7 @@ func (m *AclManager) AddPeerFiltering(
}
}
newRules := make([]firewall.Rule, 0, 2)
newRules := make([]types.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
if err != nil {
return nil, err
@@ -111,7 +111,7 @@ func (m *AclManager) AddPeerFiltering(
}
// DeletePeerRule from the firewall by rule definition
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
func (m *AclManager) DeletePeerRule(rule types.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
@@ -234,10 +234,10 @@ func (m *AclManager) Flush() error {
func (m *AclManager) addIOFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
ipset *nftables.Set,
comment string,
) (*Rule, error) {
@@ -253,7 +253,7 @@ func (m *AclManager) addIOFiltering(
var expressions []expr.Any
if proto != firewall.ProtocolALL {
if proto != types.ProtocolALL {
expressions = append(expressions, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
@@ -341,9 +341,9 @@ func (m *AclManager) addIOFiltering(
}
switch action {
case firewall.ActionAccept:
case types.ActionAccept:
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
case firewall.ActionDrop:
case types.ActionDrop:
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
}
@@ -672,7 +672,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
return nil
}
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
func generatePeerRuleId(ip net.IP, sPort *types.Port, dPort *types.Port, action types.Action, ipset *nftables.Set) string {
rulesetID := ":"
if sPort != nil {
rulesetID += sPort.String()
@@ -689,7 +689,7 @@ func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, a
return "set:" + ipset.Name + rulesetID
}
func encodePort(port firewall.Port) []byte {
func encodePort(port types.Port) []byte {
bs := make([]byte, 2)
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
return bs
@@ -701,13 +701,13 @@ func ifname(n string) []byte {
return b
}
func protoToInt(protocol firewall.Protocol) (uint8, error) {
func protoToInt(protocol types.Protocol) (uint8, error) {
switch protocol {
case firewall.ProtocolTCP:
case types.ProtocolTCP:
return unix.IPPROTO_TCP, nil
case firewall.ProtocolUDP:
case types.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
case types.ProtocolICMP:
return unix.IPPROTO_ICMP, nil
}

View File

@@ -13,7 +13,8 @@ import (
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/legacy"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -114,13 +115,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// rule ID as comment for the rule
func (m *Manager) AddPeerFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
) ([]types.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -135,11 +136,11 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
) (types.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -151,7 +152,7 @@ func (m *Manager) AddRouteFiltering(
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
func (m *Manager) DeletePeerRule(rule types.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -159,7 +160,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
}
// DeleteRouteRule deletes a routing rule
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
func (m *Manager) DeleteRouteRule(rule types.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -170,14 +171,14 @@ func (m *Manager) IsServerRouteSupported() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
func (m *Manager) AddNatRule(pair types.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddNatRule(pair)
}
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
func (m *Manager) RemoveNatRule(pair types.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -238,7 +239,7 @@ func (m *Manager) AllowNetbird() error {
// SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
return legacy.SetLegacyRouter(m.router, isLegacy)
}
// Reset firewall to the default state
@@ -330,7 +331,7 @@ func (m *Manager) Flush() error {
}
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
func (m *Manager) AddDNATRule(rule types.ForwardRule) (types.Rule, error) {
r := &Rule{
ruleID: rule.GetRuleID(),
}
@@ -338,7 +339,7 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
func (m *Manager) DeleteDNATRule(rule types.Rule) error {
return nil
}

View File

@@ -15,7 +15,7 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/iface"
)
@@ -74,7 +74,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{53}}, fw.ActionDrop, "", "")
rule, err := manager.AddPeerFiltering(ip, types.ProtocolTCP, nil, &types.Port{Values: []int{53}}, types.ActionDrop, "", "")
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
@@ -200,8 +200,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
port := &types.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, types.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
@@ -283,20 +283,20 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
})
ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{80}}, fw.ActionAccept, "", "test rule")
_, err = manager.AddPeerFiltering(ip, types.ProtocolTCP, nil, &types.Port{Values: []int{80}}, types.ActionAccept, "", "test rule")
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP,
types.ProtocolTCP,
nil,
&fw.Port{Values: []int{443}},
fw.ActionAccept,
&types.Port{Values: []int{443}},
types.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
pair := fw.RouterPair{
pair := types.RouterPair{
Source: netip.MustParsePrefix("192.168.1.0/24"),
Destination: netip.MustParsePrefix("10.0.0.0/24"),
Masquerade: true,

View File

@@ -18,7 +18,7 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
@@ -167,11 +167,11 @@ func (r *router) createContainers() error {
func (r *router) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
) (types.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
@@ -200,7 +200,7 @@ func (r *router) AddRouteFiltering(
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
// Handle protocol
if proto != firewall.ProtocolALL {
if proto != types.ProtocolALL {
protoNum, err := protoToInt(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
@@ -219,7 +219,7 @@ func (r *router) AddRouteFiltering(
exprs = append(exprs, &expr.Counter{})
var verdict expr.VerdictKind
if action == firewall.ActionAccept {
if action == types.ActionAccept {
verdict = expr.VerdictAccept
} else {
verdict = expr.VerdictDrop
@@ -248,7 +248,7 @@ func (r *router) AddRouteFiltering(
}
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
setName := firewall.GenerateSetName(sources)
setName := types.GenerateSetName(sources)
ref, err := r.ipsetCounter.Increment(setName, sources)
if err != nil {
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
@@ -270,7 +270,7 @@ func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr
return exprs, nil
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
func (r *router) DeleteRouteRule(rule types.Rule) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -307,7 +307,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them
sources = firewall.MergeIPRanges(sources)
sources = mergeIPRanges(sources)
set := &nftables.Set{
Name: setName,
@@ -403,7 +403,7 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
}
// AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
func (r *router) AddNatRule(pair types.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -420,7 +420,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
if err := r.addNatRule(types.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
}
@@ -433,7 +433,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
func (r *router) addNatRule(pair types.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
@@ -494,7 +494,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
},
)
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
ruleKey := types.GenRuleKey(types.PreroutingFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeNatRule(pair); err != nil {
@@ -584,7 +584,7 @@ func (r *router) addPostroutingRules() error {
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
func (r *router) addLegacyRouteRule(pair types.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
@@ -597,7 +597,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
ruleKey := types.GenRuleKey(types.ForwardingFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeLegacyRouteRule(pair); err != nil {
@@ -615,8 +615,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
}
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
func (r *router) removeLegacyRouteRule(pair types.RouterPair) error {
ruleKey := types.GenRuleKey(types.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
@@ -651,7 +651,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
if !strings.HasPrefix(k, types.ForwardingFormatPrefix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
@@ -829,7 +829,7 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
}
// RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
func (r *router) RemoveNatRule(pair types.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -838,7 +838,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove prerouting rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
if err := r.removeNatRule(types.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
}
@@ -854,8 +854,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return nil
}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
func (r *router) removeNatRule(pair types.RouterPair) error {
ruleKey := types.GenRuleKey(types.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule)
@@ -931,7 +931,7 @@ func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any
}
}
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
func applyPort(port *types.Port, isSource bool) []expr.Any {
if port == nil {
return nil
}
@@ -987,3 +987,27 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
return exprs
}
func mergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 {
return prefixes
}
merged := []netip.Prefix{prefixes[0]}
for _, prefix := range prefixes[1:] {
last := merged[len(merged)-1]
if last.Contains(prefix.Addr()) {
// If the current prefix is contained within the last merged prefix, skip it
continue
}
if prefix.Contains(last.Addr()) {
// If the current prefix contains the last merged prefix, replace it
merged[len(merged)-1] = prefix
} else {
// Otherwise, add the current prefix to the merged list
merged = append(merged, prefix)
}
}
return merged
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/binary"
"net/netip"
"os/exec"
"reflect"
"testing"
"github.com/coreos/go-iptables/iptables"
@@ -15,8 +16,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/firewall/types"
)
const (
@@ -97,7 +98,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
testingExpression = append(testingExpression, sourceExp...)
testingExpression = append(testingExpression, destExp...)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
natRuleKey := types.GenRuleKey(types.PreroutingFormat, testCase.InputPair)
found := 0
for _, chain := range rtr.chains {
if chain.Name == chainNamePrerouting {
@@ -139,7 +140,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
require.NoError(t, err, "should add NAT rule")
// Verify the rule was added
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
natRuleKey := types.GenRuleKey(types.PreroutingFormat, testCase.InputPair)
found := false
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules")
@@ -209,22 +210,22 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
name string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
action firewall.Action
proto types.Protocol
sPort *types.Port
dPort *types.Port
direction types.RuleDirection
action types.Action
expectSet bool
}{
{
name: "Basic TCP rule with single source",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
proto: types.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
dPort: &types.Port{Values: []int{80}},
direction: types.RuleDirectionIN,
action: types.ActionAccept,
expectSet: false,
},
{
@@ -234,77 +235,77 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
netip.MustParsePrefix("192.168.0.0/16"),
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
proto: types.ProtocolUDP,
sPort: &types.Port{Values: []int{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
direction: types.RuleDirectionOUT,
action: types.ActionDrop,
expectSet: true,
},
{
name: "All protocols rule",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
destination: netip.MustParsePrefix("0.0.0.0/0"),
proto: firewall.ProtocolALL,
proto: types.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
direction: types.RuleDirectionIN,
action: types.ActionAccept,
expectSet: false,
},
{
name: "ICMP rule",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolICMP,
proto: types.ProtocolICMP,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
direction: types.RuleDirectionIN,
action: types.ActionAccept,
expectSet: false,
},
{
name: "TCP rule with multiple source ports",
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
proto: types.ProtocolTCP,
sPort: &types.Port{Values: []int{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
direction: types.RuleDirectionOUT,
action: types.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with single IP and port range",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
proto: types.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
dPort: &types.Port{Values: []int{5000, 5100}, IsRange: true},
direction: types.RuleDirectionIN,
action: types.ActionDrop,
expectSet: false,
},
{
name: "TCP rule with source and destination ports",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
proto: types.ProtocolTCP,
sPort: &types.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &types.Port{Values: []int{22}},
direction: types.RuleDirectionOUT,
action: types.ActionAccept,
expectSet: false,
},
{
name: "Drop all incoming traffic",
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
destination: netip.MustParsePrefix("192.168.0.0/24"),
proto: firewall.ProtocolALL,
proto: types.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
direction: types.RuleDirectionIN,
action: types.ActionDrop,
expectSet: false,
},
}
@@ -441,7 +442,7 @@ func TestNftablesCreateIpSet(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setName := firewall.GenerateSetName(tt.sources)
setName := types.GenerateSetName(tt.sources)
set, err := r.createIpSet(setName, tt.sources)
if err != nil {
t.Logf("Failed to create IP set: %v", err)
@@ -506,7 +507,7 @@ func TestNftablesCreateIpSet(t *testing.T) {
}
}
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto types.Protocol, sPort, dPort *types.Port, direction types.RuleDirection, action types.Action, expectSet bool) {
t.Helper()
assert.NotNil(t, rule, "Rule should not be nil")
@@ -515,21 +516,21 @@ func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, desti
if expectSet {
assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources")
} else if len(sources) == 1 && sources[0].Bits() != 0 {
if direction == firewall.RuleDirectionIN {
if direction == types.RuleDirectionIN {
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0])
} else {
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0])
}
}
if direction == firewall.RuleDirectionIN {
if direction == types.RuleDirectionIN {
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination)
} else {
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination)
}
// Verify protocol
if proto != firewall.ProtocolALL {
if proto != types.ProtocolALL {
assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto)
}
@@ -582,7 +583,7 @@ func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) b
return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0
}
func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
func containsPort(exprs []expr.Any, port *types.Port, isSource bool) bool {
var offset uint32 = 2 // Default offset for destination port
if isSource {
offset = 0 // Offset for source port
@@ -619,7 +620,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
return false
}
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
func containsProtocol(exprs []expr.Any, proto types.Protocol) bool {
var metaFound, cmpFound bool
expectedProto, _ := protoToInt(proto)
for _, e := range exprs {
@@ -637,13 +638,13 @@ func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
return metaFound && cmpFound
}
func containsAction(exprs []expr.Any, action firewall.Action) bool {
func containsAction(exprs []expr.Any, action types.Action) bool {
for _, e := range exprs {
if verdict, ok := e.(*expr.Verdict); ok {
switch action {
case firewall.ActionAccept:
case types.ActionAccept:
return verdict.Kind == expr.VerdictAccept
case firewall.ActionDrop:
case types.ActionDrop:
return verdict.Kind == expr.VerdictDrop
}
}
@@ -714,3 +715,121 @@ func deleteWorkTable() {
}
}
}
func TestMergeIPRanges(t *testing.T) {
tests := []struct {
name string
input []netip.Prefix
expected []netip.Prefix
}{
{
name: "Empty input",
input: []netip.Prefix{},
expected: []netip.Prefix{},
},
{
name: "Single range",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Two non-overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
},
{
name: "One range containing another",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "One range containing another (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.0.0/16"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Overlapping ranges (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.128/25"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Multiple overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Partially overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.2.0/25"),
},
},
{
name: "IPv6 ranges",
input: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("2001:db8:1::/48"),
netip.MustParsePrefix("2001:db8:2::/48"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := mergeIPRanges(tt.input)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
}
})
}
}

View File

@@ -3,7 +3,7 @@ package test
import (
"net/netip"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
firewall "github.com/netbirdio/netbird/client/firewall/types"
)
var (

View File

@@ -1,11 +1,10 @@
package manager
package types
import (
"fmt"
"net/netip"
)
// ForwardRule todo figure out better place to this to avoid circular imports
type ForwardRule struct {
Protocol Protocol
DestinationPort Port

View File

@@ -0,0 +1,25 @@
package types
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/netip"
"strings"
)
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
sourcesStr.WriteString(src.String())
}
hash := sha256.Sum256([]byte(sourcesStr.String()))
shortHash := hex.EncodeToString(hash[:])[:8]
return fmt.Sprintf("nb-%s", shortHash)
}

View File

@@ -0,0 +1,71 @@
package types
import (
"net/netip"
"regexp"
"testing"
)
func TestGenerateSetName(t *testing.T) {
t.Run("Different orders result in same hash", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := GenerateSetName(prefixes1)
result2 := GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
}
})
t.Run("Result format is correct", func(t *testing.T) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
result := GenerateSetName(prefixes)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
if err != nil {
t.Fatalf("Error matching regex: %v", err)
}
if !matched {
t.Errorf("Result format is incorrect: %s", result)
}
})
t.Run("Empty input produces consistent result", func(t *testing.T) {
result1 := GenerateSetName([]netip.Prefix{})
result2 := GenerateSetName([]netip.Prefix{})
if result1 != result2 {
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
}
})
t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("2001:db8::/32"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := GenerateSetName(prefixes1)
result2 := GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
}
})
}

View File

@@ -0,0 +1,20 @@
package types
import (
"net/netip"
"sort"
)
// SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func SortPrefixes(prefixes []netip.Prefix) {
sort.Slice(prefixes, func(i, j int) bool {
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
if addrCmp != 0 {
return addrCmp < 0
}
// If IP addresses are the same, compare prefix lengths (longer prefixes first)
return prefixes[i].Bits() > prefixes[j].Bits()
})
}

View File

@@ -1,11 +1,10 @@
package manager
package types
import (
"strconv"
)
// Port of the address for firewall rule
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Port struct {
// IsRange is true Values contains two values, the first is the start port, the second is the end port
IsRange bool

View File

@@ -1,7 +1,6 @@
package manager
package types
// Protocol is the protocol of the port
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Protocol string
const (

View File

@@ -1,4 +1,4 @@
package manager
package types
import (
"net/netip"

View File

@@ -0,0 +1,43 @@
package types
import "fmt"
const (
PreroutingFormat = "netbird-prerouting-%s-%t"
NatFormat = "netbird-nat-%s-%t"
ForwardingFormat = "netbird-fwd-%s-%t"
ForwardingFormatPrefix = "netbird-fwd-"
)
// Rule abstraction should be implemented by each firewall manager
//
// Each firewall type for different OS can use different type
// of the properties to hold data of the created rule
type Rule interface {
// GetRuleID returns the rule id
GetRuleID() string
}
// RuleDirection is the traffic direction which a rule is applied
type RuleDirection int
const (
// RuleDirectionIN applies to filters that handlers incoming traffic
RuleDirectionIN RuleDirection = iota
// RuleDirectionOUT applies to filters that handlers outgoing traffic
RuleDirectionOUT
)
// Action is the action to be taken on a rule
type Action int
const (
// ActionAccept is the action to accept a packet
ActionAccept Action = iota
// ActionDrop is the action to drop a packet
ActionDrop
)
func GenRuleKey(format string, pair RouterPair) string {
return fmt.Sprintf(format, pair.ID, pair.Inverse)
}

View File

@@ -13,7 +13,8 @@ import (
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
firewall "github.com/netbirdio/netbird/client/firewall/interface"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
@@ -46,7 +47,7 @@ type Manager struct {
wgNetwork *net.IPNet
decoders sync.Pool
wgIface IFaceMapper
nativeFirewall firewall.Manager
nativeFirewall firewall.Firewall
mutex sync.RWMutex
@@ -74,7 +75,7 @@ func Create(iface IFaceMapper) (*Manager, error) {
return create(iface)
}
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Firewall) (*Manager, error) {
mgr, err := create(iface)
if err != nil {
return nil, err
@@ -134,7 +135,7 @@ func (m *Manager) IsServerRouteSupported() bool {
}
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
func (m *Manager) AddNatRule(pair types.RouterPair) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
@@ -142,7 +143,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
}
// RemoveNatRule removes a routing firewall rule
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
func (m *Manager) RemoveNatRule(pair types.RouterPair) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
@@ -155,19 +156,19 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
// rule ID as comment for the rule
func (m *Manager) AddPeerFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
_ string,
comment string,
) ([]firewall.Rule, error) {
) ([]types.Rule, error) {
r := Rule{
id: uuid.New().String(),
ip: ip,
ipLayer: layers.LayerTypeIPv6,
matchByIP: true,
drop: action == firewall.ActionDrop,
drop: action == types.ActionDrop,
comment: comment,
}
if ipNormalized := ip.To4(); ipNormalized != nil {
@@ -188,16 +189,16 @@ func (m *Manager) AddPeerFiltering(
}
switch proto {
case firewall.ProtocolTCP:
case types.ProtocolTCP:
r.protoLayer = layers.LayerTypeTCP
case firewall.ProtocolUDP:
case types.ProtocolUDP:
r.protoLayer = layers.LayerTypeUDP
case firewall.ProtocolICMP:
case types.ProtocolICMP:
r.protoLayer = layers.LayerTypeICMPv4
if r.ipLayer == layers.LayerTypeIPv6 {
r.protoLayer = layers.LayerTypeICMPv6
}
case firewall.ProtocolALL:
case types.ProtocolALL:
r.protoLayer = layerTypeAll
}
@@ -207,17 +208,17 @@ func (m *Manager) AddPeerFiltering(
}
m.incomingRules[r.ip.String()][r.id] = r
m.mutex.Unlock()
return []firewall.Rule{&r}, nil
return []types.Rule{&r}, nil
}
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto types.Protocol, sPort *types.Port, dPort *types.Port, action types.Action) (types.Rule, error) {
if m.nativeFirewall == nil {
return nil, errRouteNotSupported
}
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
func (m *Manager) DeleteRouteRule(rule types.Rule) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
@@ -225,7 +226,7 @@ func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
func (m *Manager) DeletePeerRule(rule types.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -254,12 +255,12 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
func (m *Manager) Flush() error { return nil }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
func (m *Manager) AddDNATRule(rule types.ForwardRule) (types.Rule, error) {
return nil, fmt.Errorf("not implemented")
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
func (m *Manager) DeleteDNATRule(rule types.Rule) error {
return nil
}

View File

@@ -12,7 +12,7 @@ import (
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface/device"
)
@@ -90,8 +90,8 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: false,
setupFunc: func(m *Manager) {
// Single rule allowing all traffic
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
fw.ActionAccept, "", "allow all")
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolALL, nil, nil,
types.ActionAccept, "", "allow all")
require.NoError(b, err)
},
desc: "Baseline: Single 'allow all' rule without connection tracking",
@@ -111,10 +111,10 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Add explicit rules matching return traffic pattern
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
ip := generateRandomIPs(1)[0]
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
&fw.Port{Values: []int{1024 + i}},
&fw.Port{Values: []int{80}},
fw.ActionAccept, "", "explicit return")
_, err := m.AddPeerFiltering(ip, types.ProtocolTCP,
&types.Port{Values: []int{1024 + i}},
&types.Port{Values: []int{80}},
types.ActionAccept, "", "explicit return")
require.NoError(b, err)
}
},
@@ -125,8 +125,8 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: true,
setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
fw.ActionDrop, "", "default drop")
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP, nil, nil,
types.ActionDrop, "", "default drop")
require.NoError(b, err)
},
desc: "Connection tracking with established connections",
@@ -587,10 +587,10 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP,
&types.Port{Values: []int{80}},
nil,
fw.ActionAccept, "", "return traffic")
types.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
@@ -678,10 +678,10 @@ func BenchmarkShortLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP,
&types.Port{Values: []int{80}},
nil,
fw.ActionAccept, "", "return traffic")
types.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
@@ -796,10 +796,10 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP,
&types.Port{Values: []int{80}},
nil,
fw.ActionAccept, "", "return traffic")
types.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
@@ -883,10 +883,10 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
})
if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP,
&types.Port{Values: []int{80}},
nil,
fw.ActionAccept, "", "return traffic")
types.ActionAccept, "", "return traffic")
require.NoError(b, err)
}

View File

@@ -11,7 +11,7 @@ import (
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
@@ -43,12 +43,12 @@ func TestManagerCreate(t *testing.T) {
m, err := Create(ifaceMock)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
t.Errorf("failed to create Firewall: %v", err)
return
}
if m == nil {
t.Error("Manager is nil")
t.Error("Firewall is nil")
}
}
@@ -63,14 +63,14 @@ func TestManagerAddPeerFiltering(t *testing.T) {
m, err := Create(ifaceMock)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
t.Errorf("failed to create Firewall: %v", err)
return
}
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}}
action := fw.ActionDrop
proto := types.ProtocolTCP
port := &types.Port{Values: []int{80}}
action := types.ActionDrop
comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
@@ -97,14 +97,14 @@ func TestManagerDeleteRule(t *testing.T) {
m, err := Create(ifaceMock)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
t.Errorf("failed to create Firewall: %v", err)
return
}
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}}
action := fw.ActionDrop
proto := types.ProtocolTCP
port := &types.Port{Values: []int{80}}
action := types.ActionDrop
comment := "Test rule 2"
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
@@ -138,7 +138,7 @@ func TestAddUDPPacketHook(t *testing.T) {
tests := []struct {
name string
in bool
expDir fw.RuleDirection
expDir types.RuleDirection
ip net.IP
dPort uint16
hook func([]byte) bool
@@ -147,7 +147,7 @@ func TestAddUDPPacketHook(t *testing.T) {
{
name: "Test Outgoing UDP Packet Hook",
in: false,
expDir: fw.RuleDirectionOUT,
expDir: types.RuleDirectionOUT,
ip: net.IPv4(10, 168, 0, 1),
dPort: 8000,
hook: func([]byte) bool { return true },
@@ -155,7 +155,7 @@ func TestAddUDPPacketHook(t *testing.T) {
{
name: "Test Incoming UDP Packet Hook",
in: true,
expDir: fw.RuleDirectionIN,
expDir: types.RuleDirectionIN,
ip: net.IPv6loopback,
dPort: 9000,
hook: func([]byte) bool { return false },
@@ -217,14 +217,14 @@ func TestManagerReset(t *testing.T) {
m, err := Create(ifaceMock)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
t.Errorf("failed to create Firewall: %v", err)
return
}
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}}
action := fw.ActionDrop
proto := types.ProtocolTCP
port := &types.Port{Values: []int{80}}
action := types.ActionDrop
comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
@@ -235,7 +235,7 @@ func TestManagerReset(t *testing.T) {
err = m.Reset(nil)
if err != nil {
t.Errorf("failed to reset Manager: %v", err)
t.Errorf("failed to reset Firewall: %v", err)
return
}
@@ -251,7 +251,7 @@ func TestNotMatchByIP(t *testing.T) {
m, err := Create(ifaceMock)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
t.Errorf("failed to create Firewall: %v", err)
return
}
m.wgNetwork = &net.IPNet{
@@ -260,8 +260,8 @@ func TestNotMatchByIP(t *testing.T) {
}
ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP
action := fw.ActionAccept
proto := types.ProtocolUDP
action := types.ActionAccept
comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
@@ -304,7 +304,7 @@ func TestNotMatchByIP(t *testing.T) {
}
if err = m.Reset(nil); err != nil {
t.Errorf("failed to reset Manager: %v", err)
t.Errorf("failed to reset Firewall: %v", err)
return
}
}
@@ -319,7 +319,7 @@ func TestRemovePacketHook(t *testing.T) {
// creating manager instance
manager, err := Create(iface)
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
t.Fatalf("Failed to create Firewall: %s", err)
}
defer func() {
require.NoError(t, manager.Reset(nil))
@@ -463,8 +463,8 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
port := &types.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, types.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
}