Code cleaning in firewall package

This commit is contained in:
Zoltán Papp
2025-01-25 20:29:06 +01:00
parent 8185614362
commit efa8c17d27
42 changed files with 889 additions and 868 deletions

View File

@@ -8,13 +8,13 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/interface"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (_interface.Firewall, error) {
if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}

View File

@@ -11,8 +11,8 @@ import (
"github.com/google/nftables"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/interface"
nbiptables "github.com/netbirdio/netbird/client/firewall/iptables"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -33,7 +33,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (_interface.Firewall, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
@@ -50,7 +50,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
return createUserspaceFirewall(iface, fm)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (_interface.Firewall, error) {
fm, err := createFW(iface)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
@@ -63,7 +63,7 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager)
return fm, nil
}
func createFW(iface IFaceMapper) (firewall.Manager, error) {
func createFW(iface IFaceMapper) (_interface.Firewall, error) {
switch check() {
case IPTABLES:
log.Info("creating an iptables firewall manager")
@@ -77,7 +77,7 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
func createUserspaceFirewall(iface IFaceMapper, fm _interface.Firewall) (_interface.Firewall, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)

View File

@@ -0,0 +1,67 @@
package _interface
import (
"net"
"net/netip"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Firewall is the high level abstraction of a firewall manager
//
// It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality
type Firewall interface {
Init(stateManager *statemanager.Manager) error
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddPeerFiltering adds a rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
AddPeerFiltering(
ip net.IP,
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
ipsetName string,
comment string,
) ([]types.Rule, error)
// DeletePeerRule from the firewall by rule definition
DeletePeerRule(rule types.Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto types.Protocol, sPort *types.Port, dPort *types.Port, action types.Action) (types.Rule, error)
// DeleteRouteRule deletes a routing rule
DeleteRouteRule(rule types.Rule) error
// AddNatRule inserts a routing NAT rule
AddNatRule(pair types.RouterPair) error
// RemoveNatRule removes a routing NAT rule
RemoveNatRule(pair types.RouterPair) error
// SetLegacyManagement sets the legacy management mode
SetLegacyManagement(legacy bool) error
// Reset firewall to the default state
Reset(stateManager *statemanager.Manager) error
// Flush the changes to firewall controller
Flush() error
// AddDNATRule adds a DNAT rule
AddDNATRule(types.ForwardRule) (types.Rule, error)
// DeleteDNATRule deletes a DNAT rule
// todo: do you need a string ID or the complete rule?
DeleteDNATRule(types.Rule) error
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
)
@@ -80,12 +80,12 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
func (m *aclManager) AddPeerFiltering(
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
protocol types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
ipsetName string,
) ([]firewall.Rule, error) {
) ([]types.Rule, error) {
var dPortVal, sPortVal string
if dPort != nil && dPort.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs
@@ -107,7 +107,7 @@ func (m *aclManager) AddPeerFiltering(
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
ipList.addIP(ip.String())
return []firewall.Rule{&Rule{
return []types.Rule{&Rule{
ruleID: uuid.New().String(),
ipsetName: ipsetName,
ip: ip.String(),
@@ -152,11 +152,11 @@ func (m *aclManager) AddPeerFiltering(
m.updateState()
return []firewall.Rule{rule}, nil
return []types.Rule{rule}, nil
}
// DeletePeerRule from the firewall by rule definition
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
func (m *aclManager) DeletePeerRule(rule types.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
@@ -354,7 +354,7 @@ func (m *aclManager) updateState() {
}
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.Action, ipsetName string) (specs []string) {
func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action types.Action, ipsetName string) (specs []string) {
matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" {
@@ -380,8 +380,8 @@ func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.A
return append(specs, "-j", actionToStr(action))
}
func actionToStr(action firewall.Action) string {
if action == firewall.ActionAccept {
func actionToStr(action types.Action) string {
if action == types.ActionAccept {
return "ACCEPT"
}
return "DROP"

View File

@@ -12,7 +12,8 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/legacy"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -97,13 +98,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// Comment will be ignored because some system this feature is not supported
func (m *Manager) AddPeerFiltering(
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
protocol types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
ipsetName string,
_ string,
) ([]firewall.Rule, error) {
) ([]types.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -113,11 +114,11 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
) (types.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -129,14 +130,14 @@ func (m *Manager) AddRouteFiltering(
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
func (m *Manager) DeletePeerRule(rule types.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.DeletePeerRule(rule)
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
func (m *Manager) DeleteRouteRule(rule types.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -147,14 +148,14 @@ func (m *Manager) IsServerRouteSupported() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
func (m *Manager) AddNatRule(pair types.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddNatRule(pair)
}
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
func (m *Manager) RemoveNatRule(pair types.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -162,7 +163,7 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
}
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
return legacy.SetLegacyRouter(m.router, isLegacy)
}
// Reset firewall to the default state
@@ -200,7 +201,7 @@ func (m *Manager) AllowNetbird() error {
"all",
nil,
nil,
firewall.ActionAccept,
types.ActionAccept,
"",
"",
)
@@ -213,12 +214,12 @@ func (m *Manager) AllowNetbird() error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
func (m *Manager) AddDNATRule(rule types.ForwardRule) (types.Rule, error) {
return nil, fmt.Errorf("not implemented")
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
func (m *Manager) DeleteDNATRule(rule types.Rule) error {
return nil
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/coreos/go-iptables/iptables"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/iface"
)
@@ -68,13 +68,13 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second)
}()
var rule2 []fw.Rule
var rule2 []types.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
port := &types.Port{
Values: []int{8043: 8046},
}
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, types.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule")
for _, r := range rule2 {
@@ -95,8 +95,8 @@ func TestIptablesManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) {
// add second rule
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
port := &types.Port{Values: []int{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, types.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule")
err = manager.Reset(nil)
@@ -141,13 +141,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second)
}()
var rule2 []fw.Rule
var rule2 []types.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
port := &types.Port{
Values: []int{443},
}
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, types.ActionAccept, "default", "accept HTTPS traffic from ports range")
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -214,8 +214,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
port := &types.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, types.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
}

View File

@@ -14,7 +14,7 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -42,11 +42,11 @@ const (
type routeFilteringRuleParams struct {
Sources []netip.Prefix
Destination netip.Prefix
Proto firewall.Protocol
SPort *firewall.Port
DPort *firewall.Port
Direction firewall.RuleDirection
Action firewall.Action
Proto types.Protocol
SPort *types.Port
DPort *types.Port
Direction types.RuleDirection
Action types.Action
SetName string
}
@@ -106,11 +106,11 @@ func (r *router) init(stateManager *statemanager.Manager) error {
func (r *router) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
) (types.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
@@ -118,7 +118,7 @@ func (r *router) AddRouteFiltering(
var setName string
if len(sources) > 1 {
setName = firewall.GenerateSetName(sources)
setName = types.GenerateSetName(sources)
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
@@ -146,7 +146,7 @@ func (r *router) AddRouteFiltering(
return ruleKey, nil
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
func (r *router) DeleteRouteRule(rule types.Rule) error {
ruleKey := rule.GetRuleID()
if rule, exists := r.rules[ruleKey]; exists {
@@ -202,7 +202,7 @@ func (r *router) deleteIpSet(setName string) error {
}
// AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
func (r *router) AddNatRule(pair types.RouterPair) error {
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
@@ -218,7 +218,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
if err := r.addNatRule(types.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
@@ -228,12 +228,12 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
}
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
func (r *router) RemoveNatRule(pair types.RouterPair) error {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
if err := r.removeNatRule(types.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
}
@@ -247,8 +247,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
func (r *router) addLegacyRouteRule(pair types.RouterPair) error {
ruleKey := types.GenRuleKey(types.ForwardingFormat, pair)
if err := r.removeLegacyRouteRule(pair); err != nil {
return err
@@ -264,8 +264,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
return nil
}
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
func (r *router) removeLegacyRouteRule(pair types.RouterPair) error {
ruleKey := types.GenRuleKey(types.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
@@ -293,7 +293,7 @@ func (r *router) SetLegacyManagement(isLegacy bool) {
func (r *router) RemoveAllLegacyRouteRules() error {
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
if !strings.HasPrefix(k, types.ForwardingFormatPrefix) {
continue
}
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
@@ -478,8 +478,8 @@ func (r *router) cleanJumpRules() error {
return nil
}
func (r *router) addNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
func (r *router) addNatRule(pair types.RouterPair) error {
ruleKey := types.GenRuleKey(types.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
@@ -514,8 +514,8 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
return nil
}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
func (r *router) removeNatRule(pair types.RouterPair) error {
ruleKey := types.GenRuleKey(types.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
@@ -567,7 +567,7 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
rule = append(rule, "-d", params.Destination.String())
if params.Proto != firewall.ProtocolALL {
if params.Proto != types.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
rule = append(rule, applyPort("--sport", params.SPort)...)
rule = append(rule, applyPort("--dport", params.DPort)...)
@@ -578,7 +578,7 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
return rule
}
func applyPort(flag string, port *firewall.Port) []string {
func applyPort(flag string, port *types.Port) []string {
if port == nil {
return nil
}

View File

@@ -12,8 +12,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/firewall/types"
nbnet "github.com/netbirdio/netbird/util/net"
)
@@ -54,7 +54,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
require.True(t, exists, "prerouting jump rule should exist")
pair := firewall.RouterPair{
pair := types.RouterPair{
ID: "abc",
Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: netip.MustParsePrefix("100.100.100.0/24"),
@@ -89,7 +89,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "marking rule should be inserted")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRuleKey := types.GenRuleKey(types.NatFormat, testCase.InputPair)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -114,8 +114,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
}
// Check inverse rule
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inversePair := types.GetInversePair(testCase.InputPair)
inverseRuleKey := types.GenRuleKey(types.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -164,7 +164,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRuleKey := types.GenRuleKey(types.NatFormat, testCase.InputPair)
markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -183,8 +183,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
require.False(t, found, "marking rule should not exist in the manager map")
// Check inverse rule removal
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inversePair := types.GetInversePair(testCase.InputPair)
inverseRuleKey := types.GenRuleKey(types.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
@@ -226,22 +226,22 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
name string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
action firewall.Action
proto types.Protocol
sPort *types.Port
dPort *types.Port
direction types.RuleDirection
action types.Action
expectSet bool
}{
{
name: "Basic TCP rule with single source",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
proto: types.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
dPort: &types.Port{Values: []int{80}},
direction: types.RuleDirectionIN,
action: types.ActionAccept,
expectSet: false,
},
{
@@ -251,77 +251,77 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
netip.MustParsePrefix("192.168.0.0/16"),
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
proto: types.ProtocolUDP,
sPort: &types.Port{Values: []int{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
direction: types.RuleDirectionOUT,
action: types.ActionDrop,
expectSet: true,
},
{
name: "All protocols rule",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
destination: netip.MustParsePrefix("0.0.0.0/0"),
proto: firewall.ProtocolALL,
proto: types.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
direction: types.RuleDirectionIN,
action: types.ActionAccept,
expectSet: false,
},
{
name: "ICMP rule",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolICMP,
proto: types.ProtocolICMP,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
direction: types.RuleDirectionIN,
action: types.ActionAccept,
expectSet: false,
},
{
name: "TCP rule with multiple source ports",
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
proto: types.ProtocolTCP,
sPort: &types.Port{Values: []int{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
direction: types.RuleDirectionOUT,
action: types.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with single IP and port range",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
proto: types.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
dPort: &types.Port{Values: []int{5000, 5100}, IsRange: true},
direction: types.RuleDirectionIN,
action: types.ActionDrop,
expectSet: false,
},
{
name: "TCP rule with source and destination ports",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
proto: types.ProtocolTCP,
sPort: &types.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &types.Port{Values: []int{22}},
direction: types.RuleDirectionOUT,
action: types.ActionAccept,
expectSet: false,
},
{
name: "Drop all incoming traffic",
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
destination: netip.MustParsePrefix("192.168.0.0/24"),
proto: firewall.ProtocolALL,
proto: types.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
direction: types.RuleDirectionIN,
action: types.ActionDrop,
expectSet: false,
},
}
@@ -357,7 +357,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
expectedRule := genRouteFilteringRuleSpec(params)
if tt.expectSet {
setName := firewall.GenerateSetName(tt.sources)
setName := types.GenerateSetName(tt.sources)
params.SetName = setName
expectedRule = genRouteFilteringRuleSpec(params)

View File

@@ -0,0 +1,35 @@
package legacy
import (
"fmt"
"github.com/sirupsen/logrus"
)
// Router defines the interface for legacy management operations
type Router interface {
RemoveAllLegacyRouteRules() error
GetLegacyManagement() bool
SetLegacyManagement(bool)
}
// SetLegacyRouter sets the route manager to use legacy management
func SetLegacyRouter(router Router, isLegacy bool) error {
oldLegacy := router.GetLegacyManagement()
if oldLegacy != isLegacy {
router.SetLegacyManagement(isLegacy)
logrus.Debugf("Set legacy management to %v", isLegacy)
}
// client reconnected to a newer mgmt, we need to clean up the legacy rules
if !isLegacy && oldLegacy {
if err := router.RemoveAllLegacyRouteRules(); err != nil {
return fmt.Errorf("remove legacy routing rules: %v", err)
}
logrus.Debugf("Legacy routing rules removed")
}
return nil
}

View File

@@ -1,196 +0,0 @@
package manager
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net"
"net/netip"
"sort"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const (
ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t"
PreroutingFormat = "netbird-prerouting-%s-%t"
NatFormat = "netbird-nat-%s-%t"
)
// Rule abstraction should be implemented by each firewall manager
//
// Each firewall type for different OS can use different type
// of the properties to hold data of the created rule
type Rule interface {
// GetRuleID returns the rule id
GetRuleID() string
}
// RuleDirection is the traffic direction which a rule is applied
type RuleDirection int
const (
// RuleDirectionIN applies to filters that handlers incoming traffic
RuleDirectionIN RuleDirection = iota
// RuleDirectionOUT applies to filters that handlers outgoing traffic
RuleDirectionOUT
)
// Action is the action to be taken on a rule
type Action int
const (
// ActionAccept is the action to accept a packet
ActionAccept Action = iota
// ActionDrop is the action to drop a packet
ActionDrop
)
// Manager is the high level abstraction of a firewall manager
//
// It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality
type Manager interface {
Init(stateManager *statemanager.Manager) error
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddPeerFiltering adds a rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
AddPeerFiltering(
ip net.IP,
proto Protocol,
sPort *Port,
dPort *Port,
action Action,
ipsetName string,
comment string,
) ([]Rule, error)
// DeletePeerRule from the firewall by rule definition
DeletePeerRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
// DeleteRouteRule deletes a routing rule
DeleteRouteRule(rule Rule) error
// AddNatRule inserts a routing NAT rule
AddNatRule(pair RouterPair) error
// RemoveNatRule removes a routing NAT rule
RemoveNatRule(pair RouterPair) error
// SetLegacyManagement sets the legacy management mode
SetLegacyManagement(legacy bool) error
// Reset firewall to the default state
Reset(stateManager *statemanager.Manager) error
// Flush the changes to firewall controller
Flush() error
// AddDNATRule adds a DNAT rule
AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes a DNAT rule
// todo: do you need a string ID or the complete rule?
DeleteDNATRule(Rule) error
}
func GenKey(format string, pair RouterPair) string {
return fmt.Sprintf(format, pair.ID, pair.Inverse)
}
// LegacyManager defines the interface for legacy management operations
type LegacyManager interface {
RemoveAllLegacyRouteRules() error
GetLegacyManagement() bool
SetLegacyManagement(bool)
}
// SetLegacyManagement sets the route manager to use legacy management
func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
oldLegacy := router.GetLegacyManagement()
if oldLegacy != isLegacy {
router.SetLegacyManagement(isLegacy)
log.Debugf("Set legacy management to %v", isLegacy)
}
// client reconnected to a newer mgmt, we need to clean up the legacy rules
if !isLegacy && oldLegacy {
if err := router.RemoveAllLegacyRouteRules(); err != nil {
return fmt.Errorf("remove legacy routing rules: %v", err)
}
log.Debugf("Legacy routing rules removed")
}
return nil
}
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
sourcesStr.WriteString(src.String())
}
hash := sha256.Sum256([]byte(sourcesStr.String()))
shortHash := hex.EncodeToString(hash[:])[:8]
return fmt.Sprintf("nb-%s", shortHash)
}
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 {
return prefixes
}
merged := []netip.Prefix{prefixes[0]}
for _, prefix := range prefixes[1:] {
last := merged[len(merged)-1]
if last.Contains(prefix.Addr()) {
// If the current prefix is contained within the last merged prefix, skip it
continue
}
if prefix.Contains(last.Addr()) {
// If the current prefix contains the last merged prefix, replace it
merged[len(merged)-1] = prefix
} else {
// Otherwise, add the current prefix to the merged list
merged = append(merged, prefix)
}
}
return merged
}
// SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func SortPrefixes(prefixes []netip.Prefix) {
sort.Slice(prefixes, func(i, j int) bool {
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
if addrCmp != 0 {
return addrCmp < 0
}
// If IP addresses are the same, compare prefix lengths (longer prefixes first)
return prefixes[i].Bits() > prefixes[j].Bits()
})
}

View File

@@ -1,192 +0,0 @@
package manager_test
import (
"net/netip"
"reflect"
"regexp"
"testing"
"github.com/netbirdio/netbird/client/firewall/manager"
)
func TestGenerateSetName(t *testing.T) {
t.Run("Different orders result in same hash", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
}
})
t.Run("Result format is correct", func(t *testing.T) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
result := manager.GenerateSetName(prefixes)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
if err != nil {
t.Fatalf("Error matching regex: %v", err)
}
if !matched {
t.Errorf("Result format is incorrect: %s", result)
}
})
t.Run("Empty input produces consistent result", func(t *testing.T) {
result1 := manager.GenerateSetName([]netip.Prefix{})
result2 := manager.GenerateSetName([]netip.Prefix{})
if result1 != result2 {
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
}
})
t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("2001:db8::/32"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
}
})
}
func TestMergeIPRanges(t *testing.T) {
tests := []struct {
name string
input []netip.Prefix
expected []netip.Prefix
}{
{
name: "Empty input",
input: []netip.Prefix{},
expected: []netip.Prefix{},
},
{
name: "Single range",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Two non-overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
},
{
name: "One range containing another",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "One range containing another (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.0.0/16"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Overlapping ranges (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.128/25"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Multiple overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Partially overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.2.0/25"),
},
},
{
name: "IPv6 ranges",
input: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("2001:db8:1::/48"),
netip.MustParsePrefix("2001:db8:2::/48"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := manager.MergeIPRanges(tt.input)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
}
})
}
}

View File

@@ -15,7 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
nbnet "github.com/netbirdio/netbird/util/net"
)
@@ -84,13 +84,13 @@ func (m *AclManager) init(workTable *nftables.Table) error {
// rule ID as comment for the rule
func (m *AclManager) AddPeerFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
) ([]types.Rule, error) {
var ipset *nftables.Set
if ipsetName != "" {
var err error
@@ -100,7 +100,7 @@ func (m *AclManager) AddPeerFiltering(
}
}
newRules := make([]firewall.Rule, 0, 2)
newRules := make([]types.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
if err != nil {
return nil, err
@@ -111,7 +111,7 @@ func (m *AclManager) AddPeerFiltering(
}
// DeletePeerRule from the firewall by rule definition
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
func (m *AclManager) DeletePeerRule(rule types.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
@@ -234,10 +234,10 @@ func (m *AclManager) Flush() error {
func (m *AclManager) addIOFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
ipset *nftables.Set,
comment string,
) (*Rule, error) {
@@ -253,7 +253,7 @@ func (m *AclManager) addIOFiltering(
var expressions []expr.Any
if proto != firewall.ProtocolALL {
if proto != types.ProtocolALL {
expressions = append(expressions, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
@@ -341,9 +341,9 @@ func (m *AclManager) addIOFiltering(
}
switch action {
case firewall.ActionAccept:
case types.ActionAccept:
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
case firewall.ActionDrop:
case types.ActionDrop:
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
}
@@ -672,7 +672,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
return nil
}
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
func generatePeerRuleId(ip net.IP, sPort *types.Port, dPort *types.Port, action types.Action, ipset *nftables.Set) string {
rulesetID := ":"
if sPort != nil {
rulesetID += sPort.String()
@@ -689,7 +689,7 @@ func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, a
return "set:" + ipset.Name + rulesetID
}
func encodePort(port firewall.Port) []byte {
func encodePort(port types.Port) []byte {
bs := make([]byte, 2)
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
return bs
@@ -701,13 +701,13 @@ func ifname(n string) []byte {
return b
}
func protoToInt(protocol firewall.Protocol) (uint8, error) {
func protoToInt(protocol types.Protocol) (uint8, error) {
switch protocol {
case firewall.ProtocolTCP:
case types.ProtocolTCP:
return unix.IPPROTO_TCP, nil
case firewall.ProtocolUDP:
case types.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
case types.ProtocolICMP:
return unix.IPPROTO_ICMP, nil
}

View File

@@ -13,7 +13,8 @@ import (
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/legacy"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -114,13 +115,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// rule ID as comment for the rule
func (m *Manager) AddPeerFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
) ([]types.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -135,11 +136,11 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
) (types.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -151,7 +152,7 @@ func (m *Manager) AddRouteFiltering(
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
func (m *Manager) DeletePeerRule(rule types.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -159,7 +160,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
}
// DeleteRouteRule deletes a routing rule
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
func (m *Manager) DeleteRouteRule(rule types.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -170,14 +171,14 @@ func (m *Manager) IsServerRouteSupported() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
func (m *Manager) AddNatRule(pair types.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddNatRule(pair)
}
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
func (m *Manager) RemoveNatRule(pair types.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -238,7 +239,7 @@ func (m *Manager) AllowNetbird() error {
// SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
return legacy.SetLegacyRouter(m.router, isLegacy)
}
// Reset firewall to the default state
@@ -330,7 +331,7 @@ func (m *Manager) Flush() error {
}
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
func (m *Manager) AddDNATRule(rule types.ForwardRule) (types.Rule, error) {
r := &Rule{
ruleID: rule.GetRuleID(),
}
@@ -338,7 +339,7 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
func (m *Manager) DeleteDNATRule(rule types.Rule) error {
return nil
}

View File

@@ -15,7 +15,7 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/iface"
)
@@ -74,7 +74,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{53}}, fw.ActionDrop, "", "")
rule, err := manager.AddPeerFiltering(ip, types.ProtocolTCP, nil, &types.Port{Values: []int{53}}, types.ActionDrop, "", "")
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
@@ -200,8 +200,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
port := &types.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, types.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
@@ -283,20 +283,20 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
})
ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{80}}, fw.ActionAccept, "", "test rule")
_, err = manager.AddPeerFiltering(ip, types.ProtocolTCP, nil, &types.Port{Values: []int{80}}, types.ActionAccept, "", "test rule")
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP,
types.ProtocolTCP,
nil,
&fw.Port{Values: []int{443}},
fw.ActionAccept,
&types.Port{Values: []int{443}},
types.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")
pair := fw.RouterPair{
pair := types.RouterPair{
Source: netip.MustParsePrefix("192.168.1.0/24"),
Destination: netip.MustParsePrefix("10.0.0.0/24"),
Masquerade: true,

View File

@@ -18,7 +18,7 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
@@ -167,11 +167,11 @@ func (r *router) createContainers() error {
func (r *router) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
) (types.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
@@ -200,7 +200,7 @@ func (r *router) AddRouteFiltering(
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
// Handle protocol
if proto != firewall.ProtocolALL {
if proto != types.ProtocolALL {
protoNum, err := protoToInt(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
@@ -219,7 +219,7 @@ func (r *router) AddRouteFiltering(
exprs = append(exprs, &expr.Counter{})
var verdict expr.VerdictKind
if action == firewall.ActionAccept {
if action == types.ActionAccept {
verdict = expr.VerdictAccept
} else {
verdict = expr.VerdictDrop
@@ -248,7 +248,7 @@ func (r *router) AddRouteFiltering(
}
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
setName := firewall.GenerateSetName(sources)
setName := types.GenerateSetName(sources)
ref, err := r.ipsetCounter.Increment(setName, sources)
if err != nil {
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
@@ -270,7 +270,7 @@ func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr
return exprs, nil
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
func (r *router) DeleteRouteRule(rule types.Rule) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -307,7 +307,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them
sources = firewall.MergeIPRanges(sources)
sources = mergeIPRanges(sources)
set := &nftables.Set{
Name: setName,
@@ -403,7 +403,7 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
}
// AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
func (r *router) AddNatRule(pair types.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -420,7 +420,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
if err := r.addNatRule(types.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
}
@@ -433,7 +433,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
func (r *router) addNatRule(pair types.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
@@ -494,7 +494,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
},
)
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
ruleKey := types.GenRuleKey(types.PreroutingFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeNatRule(pair); err != nil {
@@ -584,7 +584,7 @@ func (r *router) addPostroutingRules() error {
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
func (r *router) addLegacyRouteRule(pair types.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
@@ -597,7 +597,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
ruleKey := types.GenRuleKey(types.ForwardingFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeLegacyRouteRule(pair); err != nil {
@@ -615,8 +615,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
}
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
func (r *router) removeLegacyRouteRule(pair types.RouterPair) error {
ruleKey := types.GenRuleKey(types.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
@@ -651,7 +651,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
if !strings.HasPrefix(k, types.ForwardingFormatPrefix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
@@ -829,7 +829,7 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
}
// RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
func (r *router) RemoveNatRule(pair types.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -838,7 +838,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove prerouting rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
if err := r.removeNatRule(types.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
}
@@ -854,8 +854,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return nil
}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
func (r *router) removeNatRule(pair types.RouterPair) error {
ruleKey := types.GenRuleKey(types.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule)
@@ -931,7 +931,7 @@ func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any
}
}
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
func applyPort(port *types.Port, isSource bool) []expr.Any {
if port == nil {
return nil
}
@@ -987,3 +987,27 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
return exprs
}
func mergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 {
return prefixes
}
merged := []netip.Prefix{prefixes[0]}
for _, prefix := range prefixes[1:] {
last := merged[len(merged)-1]
if last.Contains(prefix.Addr()) {
// If the current prefix is contained within the last merged prefix, skip it
continue
}
if prefix.Contains(last.Addr()) {
// If the current prefix contains the last merged prefix, replace it
merged[len(merged)-1] = prefix
} else {
// Otherwise, add the current prefix to the merged list
merged = append(merged, prefix)
}
}
return merged
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/binary"
"net/netip"
"os/exec"
"reflect"
"testing"
"github.com/coreos/go-iptables/iptables"
@@ -15,8 +16,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/firewall/types"
)
const (
@@ -97,7 +98,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
testingExpression = append(testingExpression, sourceExp...)
testingExpression = append(testingExpression, destExp...)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
natRuleKey := types.GenRuleKey(types.PreroutingFormat, testCase.InputPair)
found := 0
for _, chain := range rtr.chains {
if chain.Name == chainNamePrerouting {
@@ -139,7 +140,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
require.NoError(t, err, "should add NAT rule")
// Verify the rule was added
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
natRuleKey := types.GenRuleKey(types.PreroutingFormat, testCase.InputPair)
found := false
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules")
@@ -209,22 +210,22 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
name string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
action firewall.Action
proto types.Protocol
sPort *types.Port
dPort *types.Port
direction types.RuleDirection
action types.Action
expectSet bool
}{
{
name: "Basic TCP rule with single source",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
proto: types.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
dPort: &types.Port{Values: []int{80}},
direction: types.RuleDirectionIN,
action: types.ActionAccept,
expectSet: false,
},
{
@@ -234,77 +235,77 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
netip.MustParsePrefix("192.168.0.0/16"),
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
proto: types.ProtocolUDP,
sPort: &types.Port{Values: []int{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
direction: types.RuleDirectionOUT,
action: types.ActionDrop,
expectSet: true,
},
{
name: "All protocols rule",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
destination: netip.MustParsePrefix("0.0.0.0/0"),
proto: firewall.ProtocolALL,
proto: types.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
direction: types.RuleDirectionIN,
action: types.ActionAccept,
expectSet: false,
},
{
name: "ICMP rule",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolICMP,
proto: types.ProtocolICMP,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
direction: types.RuleDirectionIN,
action: types.ActionAccept,
expectSet: false,
},
{
name: "TCP rule with multiple source ports",
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
proto: types.ProtocolTCP,
sPort: &types.Port{Values: []int{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
direction: types.RuleDirectionOUT,
action: types.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with single IP and port range",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
proto: types.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
dPort: &types.Port{Values: []int{5000, 5100}, IsRange: true},
direction: types.RuleDirectionIN,
action: types.ActionDrop,
expectSet: false,
},
{
name: "TCP rule with source and destination ports",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
proto: types.ProtocolTCP,
sPort: &types.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &types.Port{Values: []int{22}},
direction: types.RuleDirectionOUT,
action: types.ActionAccept,
expectSet: false,
},
{
name: "Drop all incoming traffic",
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
destination: netip.MustParsePrefix("192.168.0.0/24"),
proto: firewall.ProtocolALL,
proto: types.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
direction: types.RuleDirectionIN,
action: types.ActionDrop,
expectSet: false,
},
}
@@ -441,7 +442,7 @@ func TestNftablesCreateIpSet(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setName := firewall.GenerateSetName(tt.sources)
setName := types.GenerateSetName(tt.sources)
set, err := r.createIpSet(setName, tt.sources)
if err != nil {
t.Logf("Failed to create IP set: %v", err)
@@ -506,7 +507,7 @@ func TestNftablesCreateIpSet(t *testing.T) {
}
}
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto types.Protocol, sPort, dPort *types.Port, direction types.RuleDirection, action types.Action, expectSet bool) {
t.Helper()
assert.NotNil(t, rule, "Rule should not be nil")
@@ -515,21 +516,21 @@ func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, desti
if expectSet {
assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources")
} else if len(sources) == 1 && sources[0].Bits() != 0 {
if direction == firewall.RuleDirectionIN {
if direction == types.RuleDirectionIN {
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0])
} else {
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0])
}
}
if direction == firewall.RuleDirectionIN {
if direction == types.RuleDirectionIN {
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination)
} else {
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination)
}
// Verify protocol
if proto != firewall.ProtocolALL {
if proto != types.ProtocolALL {
assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto)
}
@@ -582,7 +583,7 @@ func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) b
return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0
}
func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
func containsPort(exprs []expr.Any, port *types.Port, isSource bool) bool {
var offset uint32 = 2 // Default offset for destination port
if isSource {
offset = 0 // Offset for source port
@@ -619,7 +620,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
return false
}
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
func containsProtocol(exprs []expr.Any, proto types.Protocol) bool {
var metaFound, cmpFound bool
expectedProto, _ := protoToInt(proto)
for _, e := range exprs {
@@ -637,13 +638,13 @@ func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
return metaFound && cmpFound
}
func containsAction(exprs []expr.Any, action firewall.Action) bool {
func containsAction(exprs []expr.Any, action types.Action) bool {
for _, e := range exprs {
if verdict, ok := e.(*expr.Verdict); ok {
switch action {
case firewall.ActionAccept:
case types.ActionAccept:
return verdict.Kind == expr.VerdictAccept
case firewall.ActionDrop:
case types.ActionDrop:
return verdict.Kind == expr.VerdictDrop
}
}
@@ -714,3 +715,121 @@ func deleteWorkTable() {
}
}
}
func TestMergeIPRanges(t *testing.T) {
tests := []struct {
name string
input []netip.Prefix
expected []netip.Prefix
}{
{
name: "Empty input",
input: []netip.Prefix{},
expected: []netip.Prefix{},
},
{
name: "Single range",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Two non-overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
},
{
name: "One range containing another",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "One range containing another (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.0.0/16"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Overlapping ranges (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.128/25"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Multiple overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Partially overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.2.0/25"),
},
},
{
name: "IPv6 ranges",
input: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("2001:db8:1::/48"),
netip.MustParsePrefix("2001:db8:2::/48"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := mergeIPRanges(tt.input)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
}
})
}
}

View File

@@ -3,7 +3,7 @@ package test
import (
"net/netip"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
firewall "github.com/netbirdio/netbird/client/firewall/types"
)
var (

View File

@@ -1,11 +1,10 @@
package manager
package types
import (
"fmt"
"net/netip"
)
// ForwardRule todo figure out better place to this to avoid circular imports
type ForwardRule struct {
Protocol Protocol
DestinationPort Port

View File

@@ -0,0 +1,25 @@
package types
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/netip"
"strings"
)
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
sourcesStr.WriteString(src.String())
}
hash := sha256.Sum256([]byte(sourcesStr.String()))
shortHash := hex.EncodeToString(hash[:])[:8]
return fmt.Sprintf("nb-%s", shortHash)
}

View File

@@ -0,0 +1,71 @@
package types
import (
"net/netip"
"regexp"
"testing"
)
func TestGenerateSetName(t *testing.T) {
t.Run("Different orders result in same hash", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := GenerateSetName(prefixes1)
result2 := GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
}
})
t.Run("Result format is correct", func(t *testing.T) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
result := GenerateSetName(prefixes)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
if err != nil {
t.Fatalf("Error matching regex: %v", err)
}
if !matched {
t.Errorf("Result format is incorrect: %s", result)
}
})
t.Run("Empty input produces consistent result", func(t *testing.T) {
result1 := GenerateSetName([]netip.Prefix{})
result2 := GenerateSetName([]netip.Prefix{})
if result1 != result2 {
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
}
})
t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("2001:db8::/32"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := GenerateSetName(prefixes1)
result2 := GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
}
})
}

View File

@@ -0,0 +1,20 @@
package types
import (
"net/netip"
"sort"
)
// SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func SortPrefixes(prefixes []netip.Prefix) {
sort.Slice(prefixes, func(i, j int) bool {
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
if addrCmp != 0 {
return addrCmp < 0
}
// If IP addresses are the same, compare prefix lengths (longer prefixes first)
return prefixes[i].Bits() > prefixes[j].Bits()
})
}

View File

@@ -1,11 +1,10 @@
package manager
package types
import (
"strconv"
)
// Port of the address for firewall rule
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Port struct {
// IsRange is true Values contains two values, the first is the start port, the second is the end port
IsRange bool

View File

@@ -1,7 +1,6 @@
package manager
package types
// Protocol is the protocol of the port
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Protocol string
const (

View File

@@ -1,4 +1,4 @@
package manager
package types
import (
"net/netip"

View File

@@ -0,0 +1,43 @@
package types
import "fmt"
const (
PreroutingFormat = "netbird-prerouting-%s-%t"
NatFormat = "netbird-nat-%s-%t"
ForwardingFormat = "netbird-fwd-%s-%t"
ForwardingFormatPrefix = "netbird-fwd-"
)
// Rule abstraction should be implemented by each firewall manager
//
// Each firewall type for different OS can use different type
// of the properties to hold data of the created rule
type Rule interface {
// GetRuleID returns the rule id
GetRuleID() string
}
// RuleDirection is the traffic direction which a rule is applied
type RuleDirection int
const (
// RuleDirectionIN applies to filters that handlers incoming traffic
RuleDirectionIN RuleDirection = iota
// RuleDirectionOUT applies to filters that handlers outgoing traffic
RuleDirectionOUT
)
// Action is the action to be taken on a rule
type Action int
const (
// ActionAccept is the action to accept a packet
ActionAccept Action = iota
// ActionDrop is the action to drop a packet
ActionDrop
)
func GenRuleKey(format string, pair RouterPair) string {
return fmt.Sprintf(format, pair.ID, pair.Inverse)
}

View File

@@ -13,7 +13,8 @@ import (
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
firewall "github.com/netbirdio/netbird/client/firewall/interface"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
@@ -46,7 +47,7 @@ type Manager struct {
wgNetwork *net.IPNet
decoders sync.Pool
wgIface IFaceMapper
nativeFirewall firewall.Manager
nativeFirewall firewall.Firewall
mutex sync.RWMutex
@@ -74,7 +75,7 @@ func Create(iface IFaceMapper) (*Manager, error) {
return create(iface)
}
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Firewall) (*Manager, error) {
mgr, err := create(iface)
if err != nil {
return nil, err
@@ -134,7 +135,7 @@ func (m *Manager) IsServerRouteSupported() bool {
}
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
func (m *Manager) AddNatRule(pair types.RouterPair) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
@@ -142,7 +143,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
}
// RemoveNatRule removes a routing firewall rule
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
func (m *Manager) RemoveNatRule(pair types.RouterPair) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
@@ -155,19 +156,19 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
// rule ID as comment for the rule
func (m *Manager) AddPeerFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
_ string,
comment string,
) ([]firewall.Rule, error) {
) ([]types.Rule, error) {
r := Rule{
id: uuid.New().String(),
ip: ip,
ipLayer: layers.LayerTypeIPv6,
matchByIP: true,
drop: action == firewall.ActionDrop,
drop: action == types.ActionDrop,
comment: comment,
}
if ipNormalized := ip.To4(); ipNormalized != nil {
@@ -188,16 +189,16 @@ func (m *Manager) AddPeerFiltering(
}
switch proto {
case firewall.ProtocolTCP:
case types.ProtocolTCP:
r.protoLayer = layers.LayerTypeTCP
case firewall.ProtocolUDP:
case types.ProtocolUDP:
r.protoLayer = layers.LayerTypeUDP
case firewall.ProtocolICMP:
case types.ProtocolICMP:
r.protoLayer = layers.LayerTypeICMPv4
if r.ipLayer == layers.LayerTypeIPv6 {
r.protoLayer = layers.LayerTypeICMPv6
}
case firewall.ProtocolALL:
case types.ProtocolALL:
r.protoLayer = layerTypeAll
}
@@ -207,17 +208,17 @@ func (m *Manager) AddPeerFiltering(
}
m.incomingRules[r.ip.String()][r.id] = r
m.mutex.Unlock()
return []firewall.Rule{&r}, nil
return []types.Rule{&r}, nil
}
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto types.Protocol, sPort *types.Port, dPort *types.Port, action types.Action) (types.Rule, error) {
if m.nativeFirewall == nil {
return nil, errRouteNotSupported
}
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
func (m *Manager) DeleteRouteRule(rule types.Rule) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
@@ -225,7 +226,7 @@ func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
func (m *Manager) DeletePeerRule(rule types.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -254,12 +255,12 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
func (m *Manager) Flush() error { return nil }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
func (m *Manager) AddDNATRule(rule types.ForwardRule) (types.Rule, error) {
return nil, fmt.Errorf("not implemented")
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
func (m *Manager) DeleteDNATRule(rule types.Rule) error {
return nil
}

View File

@@ -12,7 +12,7 @@ import (
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface/device"
)
@@ -90,8 +90,8 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: false,
setupFunc: func(m *Manager) {
// Single rule allowing all traffic
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
fw.ActionAccept, "", "allow all")
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolALL, nil, nil,
types.ActionAccept, "", "allow all")
require.NoError(b, err)
},
desc: "Baseline: Single 'allow all' rule without connection tracking",
@@ -111,10 +111,10 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Add explicit rules matching return traffic pattern
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
ip := generateRandomIPs(1)[0]
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
&fw.Port{Values: []int{1024 + i}},
&fw.Port{Values: []int{80}},
fw.ActionAccept, "", "explicit return")
_, err := m.AddPeerFiltering(ip, types.ProtocolTCP,
&types.Port{Values: []int{1024 + i}},
&types.Port{Values: []int{80}},
types.ActionAccept, "", "explicit return")
require.NoError(b, err)
}
},
@@ -125,8 +125,8 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: true,
setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
fw.ActionDrop, "", "default drop")
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP, nil, nil,
types.ActionDrop, "", "default drop")
require.NoError(b, err)
},
desc: "Connection tracking with established connections",
@@ -587,10 +587,10 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP,
&types.Port{Values: []int{80}},
nil,
fw.ActionAccept, "", "return traffic")
types.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
@@ -678,10 +678,10 @@ func BenchmarkShortLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP,
&types.Port{Values: []int{80}},
nil,
fw.ActionAccept, "", "return traffic")
types.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
@@ -796,10 +796,10 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP,
&types.Port{Values: []int{80}},
nil,
fw.ActionAccept, "", "return traffic")
types.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
@@ -883,10 +883,10 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
})
if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP,
&types.Port{Values: []int{80}},
nil,
fw.ActionAccept, "", "return traffic")
types.ActionAccept, "", "return traffic")
require.NoError(b, err)
}

View File

@@ -11,7 +11,7 @@ import (
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
@@ -43,12 +43,12 @@ func TestManagerCreate(t *testing.T) {
m, err := Create(ifaceMock)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
t.Errorf("failed to create Firewall: %v", err)
return
}
if m == nil {
t.Error("Manager is nil")
t.Error("Firewall is nil")
}
}
@@ -63,14 +63,14 @@ func TestManagerAddPeerFiltering(t *testing.T) {
m, err := Create(ifaceMock)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
t.Errorf("failed to create Firewall: %v", err)
return
}
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}}
action := fw.ActionDrop
proto := types.ProtocolTCP
port := &types.Port{Values: []int{80}}
action := types.ActionDrop
comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
@@ -97,14 +97,14 @@ func TestManagerDeleteRule(t *testing.T) {
m, err := Create(ifaceMock)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
t.Errorf("failed to create Firewall: %v", err)
return
}
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}}
action := fw.ActionDrop
proto := types.ProtocolTCP
port := &types.Port{Values: []int{80}}
action := types.ActionDrop
comment := "Test rule 2"
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
@@ -138,7 +138,7 @@ func TestAddUDPPacketHook(t *testing.T) {
tests := []struct {
name string
in bool
expDir fw.RuleDirection
expDir types.RuleDirection
ip net.IP
dPort uint16
hook func([]byte) bool
@@ -147,7 +147,7 @@ func TestAddUDPPacketHook(t *testing.T) {
{
name: "Test Outgoing UDP Packet Hook",
in: false,
expDir: fw.RuleDirectionOUT,
expDir: types.RuleDirectionOUT,
ip: net.IPv4(10, 168, 0, 1),
dPort: 8000,
hook: func([]byte) bool { return true },
@@ -155,7 +155,7 @@ func TestAddUDPPacketHook(t *testing.T) {
{
name: "Test Incoming UDP Packet Hook",
in: true,
expDir: fw.RuleDirectionIN,
expDir: types.RuleDirectionIN,
ip: net.IPv6loopback,
dPort: 9000,
hook: func([]byte) bool { return false },
@@ -217,14 +217,14 @@ func TestManagerReset(t *testing.T) {
m, err := Create(ifaceMock)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
t.Errorf("failed to create Firewall: %v", err)
return
}
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}}
action := fw.ActionDrop
proto := types.ProtocolTCP
port := &types.Port{Values: []int{80}}
action := types.ActionDrop
comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
@@ -235,7 +235,7 @@ func TestManagerReset(t *testing.T) {
err = m.Reset(nil)
if err != nil {
t.Errorf("failed to reset Manager: %v", err)
t.Errorf("failed to reset Firewall: %v", err)
return
}
@@ -251,7 +251,7 @@ func TestNotMatchByIP(t *testing.T) {
m, err := Create(ifaceMock)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
t.Errorf("failed to create Firewall: %v", err)
return
}
m.wgNetwork = &net.IPNet{
@@ -260,8 +260,8 @@ func TestNotMatchByIP(t *testing.T) {
}
ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP
action := fw.ActionAccept
proto := types.ProtocolUDP
action := types.ActionAccept
comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
@@ -304,7 +304,7 @@ func TestNotMatchByIP(t *testing.T) {
}
if err = m.Reset(nil); err != nil {
t.Errorf("failed to reset Manager: %v", err)
t.Errorf("failed to reset Firewall: %v", err)
return
}
}
@@ -319,7 +319,7 @@ func TestRemovePacketHook(t *testing.T) {
// creating manager instance
manager, err := Create(iface)
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
t.Fatalf("Failed to create Firewall: %s", err)
}
defer func() {
require.NoError(t, manager.Reset(nil))
@@ -463,8 +463,8 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
port := &types.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, types.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
}

View File

@@ -7,7 +7,7 @@ import (
"net/netip"
"strconv"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
)
type RuleID string
@@ -19,12 +19,12 @@ func (r RuleID) GetRuleID() string {
func GenerateRouteRuleKey(
sources []netip.Prefix,
destination netip.Prefix,
proto manager.Protocol,
sPort *manager.Port,
dPort *manager.Port,
action manager.Action,
proto types.Protocol,
sPort *types.Port,
dPort *types.Port,
action types.Action,
) RuleID {
manager.SortPrefixes(sources)
types.SortPrefixes(sources)
h := sha256.New()

View File

@@ -15,7 +15,8 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/interface"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh"
mgmProto "github.com/netbirdio/netbird/management/proto"
@@ -30,17 +31,17 @@ type Manager interface {
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
firewall firewall.Manager
firewall _interface.Firewall
ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule
peerRulesPairs map[id.RuleID][]types.Rule
routeRules map[id.RuleID]struct{}
mutex sync.Mutex
}
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
func NewDefaultManager(fm _interface.Firewall) *DefaultManager {
return &DefaultManager{
firewall: fm,
peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
peerRulesPairs: make(map[id.RuleID][]types.Rule),
routeRules: make(map[id.RuleID]struct{}),
}
}
@@ -132,7 +133,7 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
)
}
newRulePairs := make(map[id.RuleID][]firewall.Rule)
newRulePairs := make(map[id.RuleID][]types.Rule)
ipsetByRuleSelectors := make(map[string]string)
for _, r := range rules {
@@ -251,7 +252,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
) (id.RuleID, []firewall.Rule, error) {
) (id.RuleID, []types.Rule, error) {
ip := net.ParseIP(r.PeerIP)
if ip == nil {
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
@@ -267,13 +268,13 @@ func (d *DefaultManager) protoRuleToFirewallRule(
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
}
var port *firewall.Port
var port *types.Port
if r.Port != "" {
value, err := strconv.Atoi(r.Port)
if err != nil {
return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
}
port = &firewall.Port{
port = &types.Port{
Values: []int{value},
}
}
@@ -283,7 +284,7 @@ func (d *DefaultManager) protoRuleToFirewallRule(
return ruleID, rulesPair, nil
}
var rules []firewall.Rule
var rules []types.Rule
switch r.Direction {
case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
@@ -304,12 +305,12 @@ func (d *DefaultManager) protoRuleToFirewallRule(
func (d *DefaultManager) addInRules(
ip net.IP,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
protocol types.Protocol,
port *types.Port,
action types.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
) ([]types.Rule, error) {
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err)
@@ -320,12 +321,12 @@ func (d *DefaultManager) addInRules(
func (d *DefaultManager) addOutRules(
ip net.IP,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
protocol types.Protocol,
port *types.Port,
action types.Action,
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
) ([]types.Rule, error) {
if shouldSkipInvertedRule(protocol, port) {
return nil, nil
}
@@ -341,10 +342,10 @@ func (d *DefaultManager) addOutRules(
// getPeerRuleID() returns unique ID for the rule based on its parameters.
func (d *DefaultManager) getPeerRuleID(
ip net.IP,
proto firewall.Protocol,
proto types.Protocol,
direction int,
port *firewall.Port,
action firewall.Action,
port *types.Port,
action types.Action,
comment string,
) id.RuleID {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
@@ -491,7 +492,7 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
}
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]types.Rule) {
log.Debugf("rollback ACL to previous state")
for _, rules := range newRulePairs {
for _, rule := range rules {
@@ -502,49 +503,49 @@ func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
}
}
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (types.Protocol, error) {
switch protocol {
case mgmProto.RuleProtocol_TCP:
return firewall.ProtocolTCP, nil
return types.ProtocolTCP, nil
case mgmProto.RuleProtocol_UDP:
return firewall.ProtocolUDP, nil
return types.ProtocolUDP, nil
case mgmProto.RuleProtocol_ICMP:
return firewall.ProtocolICMP, nil
return types.ProtocolICMP, nil
case mgmProto.RuleProtocol_ALL:
return firewall.ProtocolALL, nil
return types.ProtocolALL, nil
default:
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
return types.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
}
}
func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) bool {
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
func shouldSkipInvertedRule(protocol types.Protocol, port *types.Port) bool {
return protocol == types.ProtocolALL || protocol == types.ProtocolICMP || port == nil
}
func convertFirewallAction(action mgmProto.RuleAction) (firewall.Action, error) {
func convertFirewallAction(action mgmProto.RuleAction) (types.Action, error) {
switch action {
case mgmProto.RuleAction_ACCEPT:
return firewall.ActionAccept, nil
return types.ActionAccept, nil
case mgmProto.RuleAction_DROP:
return firewall.ActionDrop, nil
return types.ActionDrop, nil
default:
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
return types.ActionDrop, fmt.Errorf("invalid action type: %d", action)
}
}
func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
func convertPortInfo(portInfo *mgmProto.PortInfo) *types.Port {
if portInfo == nil {
return nil
}
if portInfo.GetPort() != 0 {
return &firewall.Port{
return &types.Port{
Values: []int{int(portInfo.GetPort())},
}
}
if portInfo.GetRange() != nil {
return &firewall.Port{
return &types.Port{
IsRange: true,
Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
}

View File

@@ -7,7 +7,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/interface"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
mgmProto "github.com/netbirdio/netbird/management/proto"
@@ -56,7 +56,7 @@ func TestDefaultManager(t *testing.T) {
t.Errorf("create firewall: %v", err)
return
}
defer func(fw manager.Manager) {
defer func(fw _interface.Firewall) {
_ = fw.Reset(nil)
}(fw)
acl := NewDefaultManager(fw)
@@ -349,7 +349,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
t.Errorf("create firewall: %v", err)
return
}
defer func(fw manager.Manager) {
defer func(fw _interface.Firewall) {
_ = fw.Reset(nil)
}(fw)
acl := NewDefaultManager(fw)

View File

@@ -9,7 +9,8 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/interface"
"github.com/netbirdio/netbird/client/firewall/types"
)
const (
@@ -19,13 +20,13 @@ const (
)
type Manager struct {
firewall firewall.Manager
firewall _interface.Firewall
fwRules []firewall.Rule
fwRules []types.Rule
dnsForwarder *DNSForwarder
}
func NewManager(fw firewall.Manager) *Manager {
func NewManager(fw _interface.Firewall) *Manager {
return &Manager{
firewall: fw,
}
@@ -79,7 +80,7 @@ func (m *Manager) Stop(ctx context.Context) error {
}
func (h *Manager) allowDNSFirewall() error {
dport := &firewall.Port{
dport := &types.Port{
IsRange: false,
Values: []int{ListenPort},
}
@@ -88,7 +89,7 @@ func (h *Manager) allowDNSFirewall() error {
return nil
}
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "")
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, types.ProtocolUDP, nil, dport, types.ActionAccept, "", "")
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err

View File

@@ -25,7 +25,8 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/interface"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
@@ -169,7 +170,7 @@ type Engine struct {
statusRecorder *peer.Status
firewall firewallManager.Manager
firewall _interface.Firewall
routeManager routemanager.Manager
acl acl.Manager
dnsForwardMgr *dnsfwd.Manager
@@ -504,15 +505,15 @@ func (e *Engine) initFirewall() error {
}
rosenpassPort := e.rpManager.GetAddress().Port
port := firewallManager.Port{Values: []int{rosenpassPort}}
port := types.Port{Values: []int{rosenpassPort}}
// this rule is static and will be torn down on engine down by the firewall manager
if _, err := e.firewall.AddPeerFiltering(
net.IP{0, 0, 0, 0},
firewallManager.ProtocolUDP,
types.ProtocolUDP,
nil,
&port,
firewallManager.ActionAccept,
types.ActionAccept,
"",
"",
); err != nil {
@@ -540,10 +541,10 @@ func (e *Engine) blockLanAccess() {
if _, err := e.firewall.AddRouteFiltering(
[]netip.Prefix{v4},
network,
firewallManager.ProtocolALL,
types.ProtocolALL,
nil,
nil,
firewallManager.ActionDrop,
types.ActionDrop,
); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add fw rule for network %s: %w", network, err))
}
@@ -1774,7 +1775,7 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
}
var merr *multierror.Error
forwardingRules := make([]firewallManager.ForwardRule, 0, len(rules))
forwardingRules := make([]types.ForwardRule, 0, len(rules))
for _, rule := range rules {
proto, err := convertToFirewallProtocol(rule.GetProtocol())
if err != nil {
@@ -1800,7 +1801,7 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
continue
}
forwardRule := firewallManager.ForwardRule{
forwardRule := types.ForwardRule{
Protocol: proto,
DestinationPort: *dstPortInfo,
TranslatedAddress: translateIP,

View File

@@ -5,7 +5,7 @@ import (
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/internal/ingressgw"
)
@@ -14,192 +14,192 @@ func (e *Engine) mocForwardRules() {
e.ingressGatewayMgr = ingressgw.NewManager(e.firewall)
}
err := e.ingressGatewayMgr.Update(
[]firewallManager.ForwardRule{
[]types.ForwardRule{
{
Protocol: "tcp",
DestinationPort: firewallManager.Port{IsRange: false, Values: []int{10000}},
DestinationPort: types.Port{IsRange: false, Values: []int{10000}},
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
TranslatedPort: firewallManager.Port{IsRange: false, Values: []int{20000}},
TranslatedPort: types.Port{IsRange: false, Values: []int{20000}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10100, 10199}},
DestinationPort: types.Port{IsRange: true, Values: []int{10100, 10199}},
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20100, 20199}},
TranslatedPort: types.Port{IsRange: true, Values: []int{20100, 20199}},
},
{
Protocol: "tcp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10200, 10299}},
DestinationPort: types.Port{IsRange: true, Values: []int{10200, 10299}},
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20200, 20299}},
TranslatedPort: types.Port{IsRange: true, Values: []int{20200, 20299}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10300, 10399}},
DestinationPort: types.Port{IsRange: true, Values: []int{10300, 10399}},
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20300, 20399}},
TranslatedPort: types.Port{IsRange: true, Values: []int{20300, 20399}},
},
{
Protocol: "tcp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10100, 10199}},
DestinationPort: types.Port{IsRange: true, Values: []int{10100, 10199}},
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20100, 20199}},
TranslatedPort: types.Port{IsRange: true, Values: []int{20100, 20199}},
},
{
Protocol: "tcp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10400, 10499}},
DestinationPort: types.Port{IsRange: true, Values: []int{10400, 10499}},
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20400, 20499}},
TranslatedPort: types.Port{IsRange: true, Values: []int{20400, 20499}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10500, 10599}},
DestinationPort: types.Port{IsRange: true, Values: []int{10500, 10599}},
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20500, 20599}},
TranslatedPort: types.Port{IsRange: true, Values: []int{20500, 20599}},
},
{
Protocol: "tcp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10600, 10699}},
DestinationPort: types.Port{IsRange: true, Values: []int{10600, 10699}},
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20600, 20699}},
TranslatedPort: types.Port{IsRange: true, Values: []int{20600, 20699}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10700, 10799}},
DestinationPort: types.Port{IsRange: true, Values: []int{10700, 10799}},
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20700, 20799}},
TranslatedPort: types.Port{IsRange: true, Values: []int{20700, 20799}},
},
{
Protocol: "tcp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10800, 10899}},
DestinationPort: types.Port{IsRange: true, Values: []int{10800, 10899}},
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20800, 20899}},
TranslatedPort: types.Port{IsRange: true, Values: []int{20800, 20899}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10900, 10999}},
DestinationPort: types.Port{IsRange: true, Values: []int{10900, 10999}},
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20900, 20999}},
TranslatedPort: types.Port{IsRange: true, Values: []int{20900, 20999}},
},
{
Protocol: "tcp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11000, 11099}},
DestinationPort: types.Port{IsRange: true, Values: []int{11000, 11099}},
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21000, 21099}},
TranslatedPort: types.Port{IsRange: true, Values: []int{21000, 21099}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11100, 11199}},
DestinationPort: types.Port{IsRange: true, Values: []int{11100, 11199}},
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21100, 21199}},
TranslatedPort: types.Port{IsRange: true, Values: []int{21100, 21199}},
},
{
Protocol: "tcp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11200, 11299}},
DestinationPort: types.Port{IsRange: true, Values: []int{11200, 11299}},
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21200, 21299}},
TranslatedPort: types.Port{IsRange: true, Values: []int{21200, 21299}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11300, 11399}},
DestinationPort: types.Port{IsRange: true, Values: []int{11300, 11399}},
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21300, 21399}},
TranslatedPort: types.Port{IsRange: true, Values: []int{21300, 21399}},
},
{
Protocol: "tcp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11400, 11499}},
DestinationPort: types.Port{IsRange: true, Values: []int{11400, 11499}},
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21400, 21499}},
TranslatedPort: types.Port{IsRange: true, Values: []int{21400, 21499}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11500, 11599}},
DestinationPort: types.Port{IsRange: true, Values: []int{11500, 11599}},
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21500, 21599}},
TranslatedPort: types.Port{IsRange: true, Values: []int{21500, 21599}},
},
{
Protocol: "tcp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11600, 11699}},
DestinationPort: types.Port{IsRange: true, Values: []int{11600, 11699}},
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21600, 21699}},
TranslatedPort: types.Port{IsRange: true, Values: []int{21600, 21699}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11700, 11799}},
DestinationPort: types.Port{IsRange: true, Values: []int{11700, 11799}},
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21700, 21799}},
TranslatedPort: types.Port{IsRange: true, Values: []int{21700, 21799}},
},
{
Protocol: "tcp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11800, 11899}},
DestinationPort: types.Port{IsRange: true, Values: []int{11800, 11899}},
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21800, 21899}},
TranslatedPort: types.Port{IsRange: true, Values: []int{21800, 21899}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11900, 11999}},
DestinationPort: types.Port{IsRange: true, Values: []int{11900, 11999}},
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21900, 21999}},
TranslatedPort: types.Port{IsRange: true, Values: []int{21900, 21999}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12000, 12099}},
DestinationPort: types.Port{IsRange: true, Values: []int{12000, 12099}},
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22000, 22099}},
TranslatedPort: types.Port{IsRange: true, Values: []int{22000, 22099}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12100, 12199}},
DestinationPort: types.Port{IsRange: true, Values: []int{12100, 12199}},
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22100, 22199}},
TranslatedPort: types.Port{IsRange: true, Values: []int{22100, 22199}},
},
{
Protocol: "tcp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12200, 12299}},
DestinationPort: types.Port{IsRange: true, Values: []int{12200, 12299}},
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22200, 22299}},
TranslatedPort: types.Port{IsRange: true, Values: []int{22200, 22299}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12300, 12399}},
DestinationPort: types.Port{IsRange: true, Values: []int{12300, 12399}},
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22300, 22399}},
TranslatedPort: types.Port{IsRange: true, Values: []int{22300, 22399}},
},
{
Protocol: "tcp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12400, 12499}},
DestinationPort: types.Port{IsRange: true, Values: []int{12400, 12499}},
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22400, 22499}},
TranslatedPort: types.Port{IsRange: true, Values: []int{22400, 22499}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12500, 12599}},
DestinationPort: types.Port{IsRange: true, Values: []int{12500, 12599}},
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22500, 22599}},
TranslatedPort: types.Port{IsRange: true, Values: []int{22500, 22599}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12600, 12699}},
DestinationPort: types.Port{IsRange: true, Values: []int{12600, 12699}},
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22600, 22699}},
TranslatedPort: types.Port{IsRange: true, Values: []int{22600, 22699}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12700, 12799}},
DestinationPort: types.Port{IsRange: true, Values: []int{12700, 12799}},
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22700, 22799}},
TranslatedPort: types.Port{IsRange: true, Values: []int{22700, 22799}},
},
{
Protocol: "tcp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12800, 12899}},
DestinationPort: types.Port{IsRange: true, Values: []int{12800, 12899}},
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22800, 22899}},
TranslatedPort: types.Port{IsRange: true, Values: []int{22800, 22899}},
},
{
Protocol: "udp",
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12900, 12999}},
DestinationPort: types.Port{IsRange: true, Values: []int{12900, 12999}},
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22900, 22999}},
TranslatedPort: types.Port{IsRange: true, Values: []int{22900, 22999}},
},
},
)

View File

@@ -8,29 +8,30 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/interface"
"github.com/netbirdio/netbird/client/firewall/types"
)
type RulePair struct {
firewallManager.ForwardRule
firewallManager.Rule
types.ForwardRule
types.Rule
}
type Manager struct {
firewallManager firewallManager.Manager
firewall _interface.Firewall
rules map[string]RulePair // keys is the ID of the ForwardRule
rulesMu sync.Mutex
}
func NewManager(firewall firewallManager.Manager) *Manager {
func NewManager(firewall _interface.Firewall) *Manager {
return &Manager{
firewallManager: firewall,
rules: make(map[string]RulePair),
firewall: firewall,
rules: make(map[string]RulePair),
}
}
func (h *Manager) Update(forwardRules []firewallManager.ForwardRule) error {
func (h *Manager) Update(forwardRules []types.ForwardRule) error {
h.rulesMu.Lock()
defer h.rulesMu.Unlock()
@@ -48,7 +49,7 @@ func (h *Manager) Update(forwardRules []firewallManager.ForwardRule) error {
continue
}
rule, err := h.firewallManager.AddDNATRule(fwdRule)
rule, err := h.firewall.AddDNATRule(fwdRule)
if err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to add forward rule '%s': %v", fwdRule.String(), err))
continue
@@ -62,7 +63,7 @@ func (h *Manager) Update(forwardRules []firewallManager.ForwardRule) error {
// Remove deleted rules
for id, rulePair := range toDelete {
if err := h.firewallManager.DeleteDNATRule(rulePair.Rule); err != nil {
if err := h.firewall.DeleteDNATRule(rulePair.Rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete forward rule '%s': %v", rulePair.ForwardRule.String(), err))
}
delete(h.rules, id)
@@ -78,18 +79,18 @@ func (h *Manager) Close() error {
log.Infof("clean up all forward rules (%d)", len(h.rules))
var mErr *multierror.Error
for _, rule := range h.rules {
if err := h.firewallManager.DeleteDNATRule(rule.Rule); err != nil {
if err := h.firewall.DeleteDNATRule(rule.Rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete forward rule '%s': %v", rule, err))
}
}
return nberrors.FormatErrorOrNil(mErr)
}
func (h *Manager) Rules() []firewallManager.ForwardRule {
func (h *Manager) Rules() []types.ForwardRule {
h.rulesMu.Lock()
defer h.rulesMu.Unlock()
rules := make([]firewallManager.ForwardRule, 0, len(h.rules))
rules := make([]types.ForwardRule, 0, len(h.rules))
for _, rulePair := range h.rules {
rules = append(rules, rulePair.ForwardRule)
}

View File

@@ -6,39 +6,39 @@ import (
"net"
"net/netip"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/types"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewallManager.Protocol, error) {
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (types.Protocol, error) {
switch protocol {
case mgmProto.RuleProtocol_TCP:
return firewallManager.ProtocolTCP, nil
return types.ProtocolTCP, nil
case mgmProto.RuleProtocol_UDP:
return firewallManager.ProtocolUDP, nil
return types.ProtocolUDP, nil
case mgmProto.RuleProtocol_ICMP:
return firewallManager.ProtocolICMP, nil
return types.ProtocolICMP, nil
case mgmProto.RuleProtocol_ALL:
return firewallManager.ProtocolALL, nil
return types.ProtocolALL, nil
default:
return firewallManager.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
return types.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
}
}
// convertPortInfo todo: write validation for portInfo
func convertPortInfo(portInfo *mgmProto.PortInfo) *firewallManager.Port {
func convertPortInfo(portInfo *mgmProto.PortInfo) *types.Port {
if portInfo == nil {
return nil
}
if portInfo.GetPort() != 0 {
return &firewallManager.Port{
return &types.Port{
Values: []int{int(portInfo.GetPort())},
}
}
if portInfo.GetRange() != nil {
return &firewallManager.Port{
return &types.Port{
IsRange: true,
Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
}

View File

@@ -11,7 +11,7 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
firewall "github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/relay"

View File

@@ -14,7 +14,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/interface"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/netstack"
@@ -44,7 +44,7 @@ type Manager interface {
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
EnableServerRouter(firewall firewall.Manager) error
EnableServerRouter(firewall _interface.Firewall) error
Stop(stateManager *statemanager.Manager)
}
@@ -214,7 +214,7 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
return routeselector.NewRouteSelector()
}
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
func (m *DefaultManager) EnableServerRouter(firewall _interface.Firewall) error {
if m.disableServerRoutes {
log.Info("server routes are disabled")
return nil

View File

@@ -3,7 +3,7 @@ package routemanager
import (
"context"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/interface"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/routeselector"
@@ -78,7 +78,7 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList
}
func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
func (m *MockManager) EnableServerRouter(firewall _interface.Firewall) error {
panic("implement me")
}

View File

@@ -6,7 +6,7 @@ import (
"context"
"fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/interface"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/route"
@@ -22,6 +22,6 @@ func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error {
return nil
}
func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (*serverRouter, error) {
func newServerRouter(context.Context, iface.IWGIface, _interface.Firewall, *peer.Status) (*serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os")
}

View File

@@ -10,7 +10,8 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/interface"
"github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
@@ -21,12 +22,12 @@ type serverRouter struct {
mux sync.Mutex
ctx context.Context
routes map[route.ID]*route.Route
firewall firewall.Manager
firewall _interface.Firewall
wgInterface iface.IWGIface
statusRecorder *peer.Status
}
func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) {
func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall _interface.Firewall, statusRecorder *peer.Status) (*serverRouter, error) {
return &serverRouter{
ctx: ctx,
routes: make(map[route.ID]*route.Route),
@@ -167,7 +168,7 @@ func (m *serverRouter) cleanUp() {
m.statusRecorder.UpdateLocalPeerState(state)
}
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
func routeToRouterPair(route *route.Route) (types.RouterPair, error) {
// TODO: add ipv6
source := getDefaultPrefix(route.Network)
@@ -177,7 +178,7 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
destination = getDefaultPrefix(destination)
}
return firewall.RouterPair{
return types.RouterPair{
ID: route.ID,
Source: source,
Destination: destination,

View File

@@ -3,7 +3,7 @@ package server
import (
"context"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
firewall "github.com/netbirdio/netbird/client/firewall/types"
"github.com/netbirdio/netbird/client/proto"
)