mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
Code cleaning in firewall package
This commit is contained in:
@@ -8,13 +8,13 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/interface"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// NewFirewall creates a firewall manager instance
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (_interface.Firewall, error) {
|
||||
if !iface.IsUserspaceBind() {
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"github.com/google/nftables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/interface"
|
||||
nbiptables "github.com/netbirdio/netbird/client/firewall/iptables"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -33,7 +33,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (_interface.Firewall, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
@@ -50,7 +50,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
|
||||
return createUserspaceFirewall(iface, fm)
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (_interface.Firewall, error) {
|
||||
fm, err := createFW(iface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create firewall: %s", err)
|
||||
@@ -63,7 +63,7 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager)
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
||||
func createFW(iface IFaceMapper) (_interface.Firewall, error) {
|
||||
switch check() {
|
||||
case IPTABLES:
|
||||
log.Info("creating an iptables firewall manager")
|
||||
@@ -77,7 +77,7 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm _interface.Firewall) (_interface.Firewall, error) {
|
||||
var errUsp error
|
||||
if fm != nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
||||
67
client/firewall/interface/firewall.go
Normal file
67
client/firewall/interface/firewall.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package _interface
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// Firewall is the high level abstraction of a firewall manager
|
||||
//
|
||||
// It declares methods which handle actions required by the
|
||||
// Netbird client for ACL and routing functionality
|
||||
type Firewall interface {
|
||||
Init(stateManager *statemanager.Manager) error
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
AllowNetbird() error
|
||||
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
AddPeerFiltering(
|
||||
ip net.IP,
|
||||
proto types.Protocol,
|
||||
sPort *types.Port,
|
||||
dPort *types.Port,
|
||||
action types.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]types.Rule, error)
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
DeletePeerRule(rule types.Rule) error
|
||||
|
||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||
IsServerRouteSupported() bool
|
||||
|
||||
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto types.Protocol, sPort *types.Port, dPort *types.Port, action types.Action) (types.Rule, error)
|
||||
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
DeleteRouteRule(rule types.Rule) error
|
||||
|
||||
// AddNatRule inserts a routing NAT rule
|
||||
AddNatRule(pair types.RouterPair) error
|
||||
|
||||
// RemoveNatRule removes a routing NAT rule
|
||||
RemoveNatRule(pair types.RouterPair) error
|
||||
|
||||
// SetLegacyManagement sets the legacy management mode
|
||||
SetLegacyManagement(legacy bool) error
|
||||
|
||||
// Reset firewall to the default state
|
||||
Reset(stateManager *statemanager.Manager) error
|
||||
|
||||
// Flush the changes to firewall controller
|
||||
Flush() error
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
AddDNATRule(types.ForwardRule) (types.Rule, error)
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
// todo: do you need a string ID or the complete rule?
|
||||
DeleteDNATRule(types.Rule) error
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/nadoo/ipset"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
@@ -80,12 +80,12 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||
|
||||
func (m *aclManager) AddPeerFiltering(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
protocol types.Protocol,
|
||||
sPort *types.Port,
|
||||
dPort *types.Port,
|
||||
action types.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
) ([]types.Rule, error) {
|
||||
var dPortVal, sPortVal string
|
||||
if dPort != nil && dPort.Values != nil {
|
||||
// TODO: we support only one port per rule in current implementation of ACLs
|
||||
@@ -107,7 +107,7 @@ func (m *aclManager) AddPeerFiltering(
|
||||
// if ruleset already exists it means we already have the firewall rule
|
||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||
ipList.addIP(ip.String())
|
||||
return []firewall.Rule{&Rule{
|
||||
return []types.Rule{&Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
@@ -152,11 +152,11 @@ func (m *aclManager) AddPeerFiltering(
|
||||
|
||||
m.updateState()
|
||||
|
||||
return []firewall.Rule{rule}, nil
|
||||
return []types.Rule{rule}, nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
func (m *aclManager) DeletePeerRule(rule types.Rule) error {
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
@@ -354,7 +354,7 @@ func (m *aclManager) updateState() {
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.Action, ipsetName string) (specs []string) {
|
||||
func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action types.Action, ipsetName string) (specs []string) {
|
||||
matchByIP := true
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if ip.String() == "0.0.0.0" {
|
||||
@@ -380,8 +380,8 @@ func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.A
|
||||
return append(specs, "-j", actionToStr(action))
|
||||
}
|
||||
|
||||
func actionToStr(action firewall.Action) string {
|
||||
if action == firewall.ActionAccept {
|
||||
func actionToStr(action types.Action) string {
|
||||
if action == types.ActionAccept {
|
||||
return "ACCEPT"
|
||||
}
|
||||
return "DROP"
|
||||
|
||||
@@ -12,7 +12,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/legacy"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
@@ -97,13 +98,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
// Comment will be ignored because some system this feature is not supported
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
protocol types.Protocol,
|
||||
sPort *types.Port,
|
||||
dPort *types.Port,
|
||||
action types.Action,
|
||||
ipsetName string,
|
||||
_ string,
|
||||
) ([]firewall.Rule, error) {
|
||||
) ([]types.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -113,11 +114,11 @@ func (m *Manager) AddPeerFiltering(
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
proto types.Protocol,
|
||||
sPort *types.Port,
|
||||
dPort *types.Port,
|
||||
action types.Action,
|
||||
) (types.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -129,14 +130,14 @@ func (m *Manager) AddRouteFiltering(
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
func (m *Manager) DeletePeerRule(rule types.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.aclMgr.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
func (m *Manager) DeleteRouteRule(rule types.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -147,14 +148,14 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
func (m *Manager) AddNatRule(pair types.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddNatRule(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
func (m *Manager) RemoveNatRule(pair types.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -162,7 +163,7 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
return legacy.SetLegacyRouter(m.router, isLegacy)
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
@@ -200,7 +201,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
firewall.ActionAccept,
|
||||
types.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
@@ -213,12 +214,12 @@ func (m *Manager) AllowNetbird() error {
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
func (m *Manager) AddDNATRule(rule types.ForwardRule) (types.Rule, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
func (m *Manager) DeleteDNATRule(rule types.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
@@ -68,13 +68,13 @@ func TestIptablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule2 []fw.Rule
|
||||
var rule2 []types.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
port := &types.Port{
|
||||
Values: []int{8043: 8046},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, types.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule2 {
|
||||
@@ -95,8 +95,8 @@ func TestIptablesManager(t *testing.T) {
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Values: []int{5353}}
|
||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
port := &types.Port{Values: []int{5353}}
|
||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, types.ActionAccept, "", "accept Fake DNS traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Reset(nil)
|
||||
@@ -141,13 +141,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule2 []fw.Rule
|
||||
var rule2 []types.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
port := &types.Port{
|
||||
Values: []int{443},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
|
||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, types.ActionAccept, "default", "accept HTTPS traffic from ports range")
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
@@ -214,8 +214,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
port := &types.Port{Values: []int{1000 + i}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, types.ActionAccept, "", "accept HTTP traffic")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -42,11 +42,11 @@ const (
|
||||
type routeFilteringRuleParams struct {
|
||||
Sources []netip.Prefix
|
||||
Destination netip.Prefix
|
||||
Proto firewall.Protocol
|
||||
SPort *firewall.Port
|
||||
DPort *firewall.Port
|
||||
Direction firewall.RuleDirection
|
||||
Action firewall.Action
|
||||
Proto types.Protocol
|
||||
SPort *types.Port
|
||||
DPort *types.Port
|
||||
Direction types.RuleDirection
|
||||
Action types.Action
|
||||
SetName string
|
||||
}
|
||||
|
||||
@@ -106,11 +106,11 @@ func (r *router) init(stateManager *statemanager.Manager) error {
|
||||
func (r *router) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
proto types.Protocol,
|
||||
sPort *types.Port,
|
||||
dPort *types.Port,
|
||||
action types.Action,
|
||||
) (types.Rule, error) {
|
||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||
return ruleKey, nil
|
||||
@@ -118,7 +118,7 @@ func (r *router) AddRouteFiltering(
|
||||
|
||||
var setName string
|
||||
if len(sources) > 1 {
|
||||
setName = firewall.GenerateSetName(sources)
|
||||
setName = types.GenerateSetName(sources)
|
||||
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
@@ -146,7 +146,7 @@ func (r *router) AddRouteFiltering(
|
||||
return ruleKey, nil
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
func (r *router) DeleteRouteRule(rule types.Rule) error {
|
||||
ruleKey := rule.GetRuleID()
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
@@ -202,7 +202,7 @@ func (r *router) deleteIpSet(setName string) error {
|
||||
}
|
||||
|
||||
// AddNatRule inserts an iptables rule pair into the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
func (r *router) AddNatRule(pair types.RouterPair) error {
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
@@ -218,7 +218,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("add nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
if err := r.addNatRule(types.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
@@ -228,12 +228,12 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
func (r *router) RemoveNatRule(pair types.RouterPair) error {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
if err := r.removeNatRule(types.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
@@ -247,8 +247,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
func (r *router) addLegacyRouteRule(pair types.RouterPair) error {
|
||||
ruleKey := types.GenRuleKey(types.ForwardingFormat, pair)
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return err
|
||||
@@ -264,8 +264,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
func (r *router) removeLegacyRouteRule(pair types.RouterPair) error {
|
||||
ruleKey := types.GenRuleKey(types.ForwardingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
@@ -293,7 +293,7 @@ func (r *router) SetLegacyManagement(isLegacy bool) {
|
||||
func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
var merr *multierror.Error
|
||||
for k, rule := range r.rules {
|
||||
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
||||
if !strings.HasPrefix(k, types.ForwardingFormatPrefix) {
|
||||
continue
|
||||
}
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
||||
@@ -478,8 +478,8 @@ func (r *router) cleanJumpRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
func (r *router) addNatRule(pair types.RouterPair) error {
|
||||
ruleKey := types.GenRuleKey(types.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
|
||||
@@ -514,8 +514,8 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
func (r *router) removeNatRule(pair types.RouterPair) error {
|
||||
ruleKey := types.GenRuleKey(types.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
|
||||
@@ -567,7 +567,7 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||
|
||||
rule = append(rule, "-d", params.Destination.String())
|
||||
|
||||
if params.Proto != firewall.ProtocolALL {
|
||||
if params.Proto != types.ProtocolALL {
|
||||
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
||||
rule = append(rule, applyPort("--sport", params.SPort)...)
|
||||
rule = append(rule, applyPort("--dport", params.DPort)...)
|
||||
@@ -578,7 +578,7 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||
return rule
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
func applyPort(flag string, port *types.Port) []string {
|
||||
if port == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@@ -54,7 +54,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
|
||||
require.True(t, exists, "prerouting jump rule should exist")
|
||||
|
||||
pair := firewall.RouterPair{
|
||||
pair := types.RouterPair{
|
||||
ID: "abc",
|
||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
||||
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
||||
@@ -89,7 +89,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
err = manager.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "marking rule should be inserted")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRuleKey := types.GenRuleKey(types.NatFormat, testCase.InputPair)
|
||||
markingRule := []string{
|
||||
"-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
@@ -114,8 +114,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check inverse rule
|
||||
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||
inversePair := types.GetInversePair(testCase.InputPair)
|
||||
inverseRuleKey := types.GenRuleKey(types.NatFormat, inversePair)
|
||||
inverseMarkingRule := []string{
|
||||
"!", "-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
@@ -164,7 +164,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
err = manager.RemoveNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRuleKey := types.GenRuleKey(types.NatFormat, testCase.InputPair)
|
||||
markingRule := []string{
|
||||
"-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
@@ -183,8 +183,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
require.False(t, found, "marking rule should not exist in the manager map")
|
||||
|
||||
// Check inverse rule removal
|
||||
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||
inversePair := types.GetInversePair(testCase.InputPair)
|
||||
inverseRuleKey := types.GenRuleKey(types.NatFormat, inversePair)
|
||||
inverseMarkingRule := []string{
|
||||
"!", "-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
@@ -226,22 +226,22 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
name string
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
direction firewall.RuleDirection
|
||||
action firewall.Action
|
||||
proto types.Protocol
|
||||
sPort *types.Port
|
||||
dPort *types.Port
|
||||
direction types.RuleDirection
|
||||
action types.Action
|
||||
expectSet bool
|
||||
}{
|
||||
{
|
||||
name: "Basic TCP rule with single source",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
proto: types.ProtocolTCP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{80}},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
dPort: &types.Port{Values: []int{80}},
|
||||
direction: types.RuleDirectionIN,
|
||||
action: types.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
@@ -251,77 +251,77 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
||||
proto: types.ProtocolUDP,
|
||||
sPort: &types.Port{Values: []int{1024, 2048}, IsRange: true},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionDrop,
|
||||
direction: types.RuleDirectionOUT,
|
||||
action: types.ActionDrop,
|
||||
expectSet: true,
|
||||
},
|
||||
{
|
||||
name: "All protocols rule",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
proto: firewall.ProtocolALL,
|
||||
proto: types.ProtocolALL,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
direction: types.RuleDirectionIN,
|
||||
action: types.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "ICMP rule",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolICMP,
|
||||
proto: types.ProtocolICMP,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
direction: types.RuleDirectionIN,
|
||||
action: types.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "TCP rule with multiple source ports",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
||||
proto: types.ProtocolTCP,
|
||||
sPort: &types.Port{Values: []int{80, 443, 8080}},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
direction: types.RuleDirectionOUT,
|
||||
action: types.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "UDP rule with single IP and port range",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
proto: types.ProtocolUDP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
dPort: &types.Port{Values: []int{5000, 5100}, IsRange: true},
|
||||
direction: types.RuleDirectionIN,
|
||||
action: types.ActionDrop,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "TCP rule with source and destination ports",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []int{22}},
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
proto: types.ProtocolTCP,
|
||||
sPort: &types.Port{Values: []int{1024, 65535}, IsRange: true},
|
||||
dPort: &types.Port{Values: []int{22}},
|
||||
direction: types.RuleDirectionOUT,
|
||||
action: types.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "Drop all incoming traffic",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
proto: firewall.ProtocolALL,
|
||||
proto: types.ProtocolALL,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
direction: types.RuleDirectionIN,
|
||||
action: types.ActionDrop,
|
||||
expectSet: false,
|
||||
},
|
||||
}
|
||||
@@ -357,7 +357,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
expectedRule := genRouteFilteringRuleSpec(params)
|
||||
|
||||
if tt.expectSet {
|
||||
setName := firewall.GenerateSetName(tt.sources)
|
||||
setName := types.GenerateSetName(tt.sources)
|
||||
params.SetName = setName
|
||||
expectedRule = genRouteFilteringRuleSpec(params)
|
||||
|
||||
|
||||
35
client/firewall/legacy/router.go
Normal file
35
client/firewall/legacy/router.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package legacy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Router defines the interface for legacy management operations
|
||||
type Router interface {
|
||||
RemoveAllLegacyRouteRules() error
|
||||
GetLegacyManagement() bool
|
||||
SetLegacyManagement(bool)
|
||||
}
|
||||
|
||||
// SetLegacyRouter sets the route manager to use legacy management
|
||||
func SetLegacyRouter(router Router, isLegacy bool) error {
|
||||
oldLegacy := router.GetLegacyManagement()
|
||||
|
||||
if oldLegacy != isLegacy {
|
||||
router.SetLegacyManagement(isLegacy)
|
||||
logrus.Debugf("Set legacy management to %v", isLegacy)
|
||||
}
|
||||
|
||||
// client reconnected to a newer mgmt, we need to clean up the legacy rules
|
||||
if !isLegacy && oldLegacy {
|
||||
if err := router.RemoveAllLegacyRouteRules(); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rules: %v", err)
|
||||
}
|
||||
|
||||
logrus.Debugf("Legacy routing rules removed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,196 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const (
|
||||
ForwardingFormatPrefix = "netbird-fwd-"
|
||||
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||
PreroutingFormat = "netbird-prerouting-%s-%t"
|
||||
NatFormat = "netbird-nat-%s-%t"
|
||||
)
|
||||
|
||||
// Rule abstraction should be implemented by each firewall manager
|
||||
//
|
||||
// Each firewall type for different OS can use different type
|
||||
// of the properties to hold data of the created rule
|
||||
type Rule interface {
|
||||
// GetRuleID returns the rule id
|
||||
GetRuleID() string
|
||||
}
|
||||
|
||||
// RuleDirection is the traffic direction which a rule is applied
|
||||
type RuleDirection int
|
||||
|
||||
const (
|
||||
// RuleDirectionIN applies to filters that handlers incoming traffic
|
||||
RuleDirectionIN RuleDirection = iota
|
||||
// RuleDirectionOUT applies to filters that handlers outgoing traffic
|
||||
RuleDirectionOUT
|
||||
)
|
||||
|
||||
// Action is the action to be taken on a rule
|
||||
type Action int
|
||||
|
||||
const (
|
||||
// ActionAccept is the action to accept a packet
|
||||
ActionAccept Action = iota
|
||||
// ActionDrop is the action to drop a packet
|
||||
ActionDrop
|
||||
)
|
||||
|
||||
// Manager is the high level abstraction of a firewall manager
|
||||
//
|
||||
// It declares methods which handle actions required by the
|
||||
// Netbird client for ACL and routing functionality
|
||||
type Manager interface {
|
||||
Init(stateManager *statemanager.Manager) error
|
||||
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
AllowNetbird() error
|
||||
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
AddPeerFiltering(
|
||||
ip net.IP,
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
dPort *Port,
|
||||
action Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]Rule, error)
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
DeletePeerRule(rule Rule) error
|
||||
|
||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||
IsServerRouteSupported() bool
|
||||
|
||||
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
|
||||
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
DeleteRouteRule(rule Rule) error
|
||||
|
||||
// AddNatRule inserts a routing NAT rule
|
||||
AddNatRule(pair RouterPair) error
|
||||
|
||||
// RemoveNatRule removes a routing NAT rule
|
||||
RemoveNatRule(pair RouterPair) error
|
||||
|
||||
// SetLegacyManagement sets the legacy management mode
|
||||
SetLegacyManagement(legacy bool) error
|
||||
|
||||
// Reset firewall to the default state
|
||||
Reset(stateManager *statemanager.Manager) error
|
||||
|
||||
// Flush the changes to firewall controller
|
||||
Flush() error
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
AddDNATRule(ForwardRule) (Rule, error)
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
// todo: do you need a string ID or the complete rule?
|
||||
DeleteDNATRule(Rule) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
return fmt.Sprintf(format, pair.ID, pair.Inverse)
|
||||
}
|
||||
|
||||
// LegacyManager defines the interface for legacy management operations
|
||||
type LegacyManager interface {
|
||||
RemoveAllLegacyRouteRules() error
|
||||
GetLegacyManagement() bool
|
||||
SetLegacyManagement(bool)
|
||||
}
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management
|
||||
func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
|
||||
oldLegacy := router.GetLegacyManagement()
|
||||
|
||||
if oldLegacy != isLegacy {
|
||||
router.SetLegacyManagement(isLegacy)
|
||||
log.Debugf("Set legacy management to %v", isLegacy)
|
||||
}
|
||||
|
||||
// client reconnected to a newer mgmt, we need to clean up the legacy rules
|
||||
if !isLegacy && oldLegacy {
|
||||
if err := router.RemoveAllLegacyRouteRules(); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rules: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("Legacy routing rules removed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateSetName generates a unique name for an ipset based on the given sources.
|
||||
func GenerateSetName(sources []netip.Prefix) string {
|
||||
// sort for consistent naming
|
||||
SortPrefixes(sources)
|
||||
|
||||
var sourcesStr strings.Builder
|
||||
for _, src := range sources {
|
||||
sourcesStr.WriteString(src.String())
|
||||
}
|
||||
|
||||
hash := sha256.Sum256([]byte(sourcesStr.String()))
|
||||
shortHash := hex.EncodeToString(hash[:])[:8]
|
||||
|
||||
return fmt.Sprintf("nb-%s", shortHash)
|
||||
}
|
||||
|
||||
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
|
||||
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||
if len(prefixes) == 0 {
|
||||
return prefixes
|
||||
}
|
||||
|
||||
merged := []netip.Prefix{prefixes[0]}
|
||||
for _, prefix := range prefixes[1:] {
|
||||
last := merged[len(merged)-1]
|
||||
if last.Contains(prefix.Addr()) {
|
||||
// If the current prefix is contained within the last merged prefix, skip it
|
||||
continue
|
||||
}
|
||||
if prefix.Contains(last.Addr()) {
|
||||
// If the current prefix contains the last merged prefix, replace it
|
||||
merged[len(merged)-1] = prefix
|
||||
} else {
|
||||
// Otherwise, add the current prefix to the merged list
|
||||
merged = append(merged, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
// SortPrefixes sorts the given slice of netip.Prefix in place.
|
||||
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
||||
func SortPrefixes(prefixes []netip.Prefix) {
|
||||
sort.Slice(prefixes, func(i, j int) bool {
|
||||
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
|
||||
if addrCmp != 0 {
|
||||
return addrCmp < 0
|
||||
}
|
||||
|
||||
// If IP addresses are the same, compare prefix lengths (longer prefixes first)
|
||||
return prefixes[i].Bits() > prefixes[j].Bits()
|
||||
})
|
||||
}
|
||||
@@ -1,192 +0,0 @@
|
||||
package manager_test
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
func TestGenerateSetName(t *testing.T) {
|
||||
t.Run("Different orders result in same hash", func(t *testing.T) {
|
||||
prefixes1 := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
}
|
||||
prefixes2 := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
result1 := manager.GenerateSetName(prefixes1)
|
||||
result2 := manager.GenerateSetName(prefixes2)
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Result format is correct", func(t *testing.T) {
|
||||
prefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
}
|
||||
|
||||
result := manager.GenerateSetName(prefixes)
|
||||
|
||||
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
|
||||
if err != nil {
|
||||
t.Fatalf("Error matching regex: %v", err)
|
||||
}
|
||||
if !matched {
|
||||
t.Errorf("Result format is incorrect: %s", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Empty input produces consistent result", func(t *testing.T) {
|
||||
result1 := manager.GenerateSetName([]netip.Prefix{})
|
||||
result2 := manager.GenerateSetName([]netip.Prefix{})
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
|
||||
prefixes1 := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
}
|
||||
prefixes2 := []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
result1 := manager.GenerateSetName(prefixes1)
|
||||
result2 := manager.GenerateSetName(prefixes2)
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMergeIPRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []netip.Prefix
|
||||
expected []netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Empty input",
|
||||
input: []netip.Prefix{},
|
||||
expected: []netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Single range",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Two non-overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "One range containing another",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "One range containing another (different order)",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.128/25"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Overlapping ranges (different order)",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.128/25"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Multiple overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.2.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.128/25"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Partially overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/23"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.2.0/25"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/23"),
|
||||
netip.MustParsePrefix("192.168.2.0/25"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "IPv6 ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
netip.MustParsePrefix("2001:db8:1::/48"),
|
||||
netip.MustParsePrefix("2001:db8:2::/48"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := manager.MergeIPRanges(tt.input)
|
||||
if !reflect.DeepEqual(result, tt.expected) {
|
||||
t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
@@ -84,13 +84,13 @@ func (m *AclManager) init(workTable *nftables.Table) error {
|
||||
// rule ID as comment for the rule
|
||||
func (m *AclManager) AddPeerFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
proto types.Protocol,
|
||||
sPort *types.Port,
|
||||
dPort *types.Port,
|
||||
action types.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
) ([]types.Rule, error) {
|
||||
var ipset *nftables.Set
|
||||
if ipsetName != "" {
|
||||
var err error
|
||||
@@ -100,7 +100,7 @@ func (m *AclManager) AddPeerFiltering(
|
||||
}
|
||||
}
|
||||
|
||||
newRules := make([]firewall.Rule, 0, 2)
|
||||
newRules := make([]types.Rule, 0, 2)
|
||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -111,7 +111,7 @@ func (m *AclManager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
func (m *AclManager) DeletePeerRule(rule types.Rule) error {
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
@@ -234,10 +234,10 @@ func (m *AclManager) Flush() error {
|
||||
|
||||
func (m *AclManager) addIOFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
proto types.Protocol,
|
||||
sPort *types.Port,
|
||||
dPort *types.Port,
|
||||
action types.Action,
|
||||
ipset *nftables.Set,
|
||||
comment string,
|
||||
) (*Rule, error) {
|
||||
@@ -253,7 +253,7 @@ func (m *AclManager) addIOFiltering(
|
||||
|
||||
var expressions []expr.Any
|
||||
|
||||
if proto != firewall.ProtocolALL {
|
||||
if proto != types.ProtocolALL {
|
||||
expressions = append(expressions, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
@@ -341,9 +341,9 @@ func (m *AclManager) addIOFiltering(
|
||||
}
|
||||
|
||||
switch action {
|
||||
case firewall.ActionAccept:
|
||||
case types.ActionAccept:
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
||||
case firewall.ActionDrop:
|
||||
case types.ActionDrop:
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
}
|
||||
|
||||
@@ -672,7 +672,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||
func generatePeerRuleId(ip net.IP, sPort *types.Port, dPort *types.Port, action types.Action, ipset *nftables.Set) string {
|
||||
rulesetID := ":"
|
||||
if sPort != nil {
|
||||
rulesetID += sPort.String()
|
||||
@@ -689,7 +689,7 @@ func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, a
|
||||
return "set:" + ipset.Name + rulesetID
|
||||
}
|
||||
|
||||
func encodePort(port firewall.Port) []byte {
|
||||
func encodePort(port types.Port) []byte {
|
||||
bs := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
|
||||
return bs
|
||||
@@ -701,13 +701,13 @@ func ifname(n string) []byte {
|
||||
return b
|
||||
}
|
||||
|
||||
func protoToInt(protocol firewall.Protocol) (uint8, error) {
|
||||
func protoToInt(protocol types.Protocol) (uint8, error) {
|
||||
switch protocol {
|
||||
case firewall.ProtocolTCP:
|
||||
case types.ProtocolTCP:
|
||||
return unix.IPPROTO_TCP, nil
|
||||
case firewall.ProtocolUDP:
|
||||
case types.ProtocolUDP:
|
||||
return unix.IPPROTO_UDP, nil
|
||||
case firewall.ProtocolICMP:
|
||||
case types.ProtocolICMP:
|
||||
return unix.IPPROTO_ICMP, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,8 @@ import (
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/legacy"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
@@ -114,13 +115,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
proto types.Protocol,
|
||||
sPort *types.Port,
|
||||
dPort *types.Port,
|
||||
action types.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
) ([]types.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -135,11 +136,11 @@ func (m *Manager) AddPeerFiltering(
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
proto types.Protocol,
|
||||
sPort *types.Port,
|
||||
dPort *types.Port,
|
||||
action types.Action,
|
||||
) (types.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -151,7 +152,7 @@ func (m *Manager) AddRouteFiltering(
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
func (m *Manager) DeletePeerRule(rule types.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -159,7 +160,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
func (m *Manager) DeleteRouteRule(rule types.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -170,14 +171,14 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
func (m *Manager) AddNatRule(pair types.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddNatRule(pair)
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
func (m *Manager) RemoveNatRule(pair types.RouterPair) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -238,7 +239,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||
return legacy.SetLegacyRouter(m.router, isLegacy)
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
@@ -330,7 +331,7 @@ func (m *Manager) Flush() error {
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
func (m *Manager) AddDNATRule(rule types.ForwardRule) (types.Rule, error) {
|
||||
r := &Rule{
|
||||
ruleID: rule.GetRuleID(),
|
||||
}
|
||||
@@ -338,7 +339,7 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
func (m *Manager) DeleteDNATRule(rule types.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
@@ -74,7 +74,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{53}}, fw.ActionDrop, "", "")
|
||||
rule, err := manager.AddPeerFiltering(ip, types.ProtocolTCP, nil, &types.Port{Values: []int{53}}, types.ActionDrop, "", "")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
@@ -200,8 +200,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
port := &types.Port{Values: []int{1000 + i}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, types.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
@@ -283,20 +283,20 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
})
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{80}}, fw.ActionAccept, "", "test rule")
|
||||
_, err = manager.AddPeerFiltering(ip, types.ProtocolTCP, nil, &types.Port{Values: []int{80}}, types.ActionAccept, "", "test rule")
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
fw.ProtocolTCP,
|
||||
types.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{443}},
|
||||
fw.ActionAccept,
|
||||
&types.Port{Values: []int{443}},
|
||||
types.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err, "failed to add route filtering rule")
|
||||
|
||||
pair := fw.RouterPair{
|
||||
pair := types.RouterPair{
|
||||
Source: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
Destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
Masquerade: true,
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
@@ -167,11 +167,11 @@ func (r *router) createContainers() error {
|
||||
func (r *router) AddRouteFiltering(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
proto types.Protocol,
|
||||
sPort *types.Port,
|
||||
dPort *types.Port,
|
||||
action types.Action,
|
||||
) (types.Rule, error) {
|
||||
|
||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||
@@ -200,7 +200,7 @@ func (r *router) AddRouteFiltering(
|
||||
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
|
||||
|
||||
// Handle protocol
|
||||
if proto != firewall.ProtocolALL {
|
||||
if proto != types.ProtocolALL {
|
||||
protoNum, err := protoToInt(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
@@ -219,7 +219,7 @@ func (r *router) AddRouteFiltering(
|
||||
exprs = append(exprs, &expr.Counter{})
|
||||
|
||||
var verdict expr.VerdictKind
|
||||
if action == firewall.ActionAccept {
|
||||
if action == types.ActionAccept {
|
||||
verdict = expr.VerdictAccept
|
||||
} else {
|
||||
verdict = expr.VerdictDrop
|
||||
@@ -248,7 +248,7 @@ func (r *router) AddRouteFiltering(
|
||||
}
|
||||
|
||||
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
|
||||
setName := firewall.GenerateSetName(sources)
|
||||
setName := types.GenerateSetName(sources)
|
||||
ref, err := r.ipsetCounter.Increment(setName, sources)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
|
||||
@@ -270,7 +270,7 @@ func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
func (r *router) DeleteRouteRule(rule types.Rule) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
@@ -307,7 +307,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||
|
||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
|
||||
// overlapping prefixes will result in an error, so we need to merge them
|
||||
sources = firewall.MergeIPRanges(sources)
|
||||
sources = mergeIPRanges(sources)
|
||||
|
||||
set := &nftables.Set{
|
||||
Name: setName,
|
||||
@@ -403,7 +403,7 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
||||
}
|
||||
|
||||
// AddNatRule appends a nftables rule pair to the nat chain
|
||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
func (r *router) AddNatRule(pair types.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
@@ -420,7 +420,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("add nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
if err := r.addNatRule(types.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -433,7 +433,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
func (r *router) addNatRule(pair types.RouterPair) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
@@ -494,7 +494,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||
},
|
||||
)
|
||||
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
ruleKey := types.GenRuleKey(types.PreroutingFormat, pair)
|
||||
|
||||
if _, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
@@ -584,7 +584,7 @@ func (r *router) addPostroutingRules() error {
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
func (r *router) addLegacyRouteRule(pair types.RouterPair) error {
|
||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
||||
|
||||
@@ -597,7 +597,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
|
||||
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
|
||||
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
ruleKey := types.GenRuleKey(types.ForwardingFormat, pair)
|
||||
|
||||
if _, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
@@ -615,8 +615,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
func (r *router) removeLegacyRouteRule(pair types.RouterPair) error {
|
||||
ruleKey := types.GenRuleKey(types.ForwardingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
@@ -651,7 +651,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
|
||||
var merr *multierror.Error
|
||||
for k, rule := range r.rules {
|
||||
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
||||
if !strings.HasPrefix(k, types.ForwardingFormatPrefix) {
|
||||
continue
|
||||
}
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
@@ -829,7 +829,7 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
||||
}
|
||||
|
||||
// RemoveNatRule removes the prerouting mark rule
|
||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
func (r *router) RemoveNatRule(pair types.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
@@ -838,7 +838,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
if err := r.removeNatRule(types.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||
}
|
||||
|
||||
@@ -854,8 +854,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
func (r *router) removeNatRule(pair types.RouterPair) error {
|
||||
ruleKey := types.GenRuleKey(types.PreroutingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
err := r.conn.DelRule(rule)
|
||||
@@ -931,7 +931,7 @@ func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any
|
||||
}
|
||||
}
|
||||
|
||||
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
func applyPort(port *types.Port, isSource bool) []expr.Any {
|
||||
if port == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -987,3 +987,27 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
|
||||
return exprs
|
||||
}
|
||||
|
||||
func mergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||
if len(prefixes) == 0 {
|
||||
return prefixes
|
||||
}
|
||||
|
||||
merged := []netip.Prefix{prefixes[0]}
|
||||
for _, prefix := range prefixes[1:] {
|
||||
last := merged[len(merged)-1]
|
||||
if last.Contains(prefix.Addr()) {
|
||||
// If the current prefix is contained within the last merged prefix, skip it
|
||||
continue
|
||||
}
|
||||
if prefix.Contains(last.Addr()) {
|
||||
// If the current prefix contains the last merged prefix, replace it
|
||||
merged[len(merged)-1] = prefix
|
||||
} else {
|
||||
// Otherwise, add the current prefix to the merged list
|
||||
merged = append(merged, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
@@ -15,8 +16,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -97,7 +98,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
testingExpression = append(testingExpression, sourceExp...)
|
||||
testingExpression = append(testingExpression, destExp...)
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
natRuleKey := types.GenRuleKey(types.PreroutingFormat, testCase.InputPair)
|
||||
found := 0
|
||||
for _, chain := range rtr.chains {
|
||||
if chain.Name == chainNamePrerouting {
|
||||
@@ -139,7 +140,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
require.NoError(t, err, "should add NAT rule")
|
||||
|
||||
// Verify the rule was added
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
natRuleKey := types.GenRuleKey(types.PreroutingFormat, testCase.InputPair)
|
||||
found := false
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
||||
require.NoError(t, err, "should list rules")
|
||||
@@ -209,22 +210,22 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
name string
|
||||
sources []netip.Prefix
|
||||
destination netip.Prefix
|
||||
proto firewall.Protocol
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
direction firewall.RuleDirection
|
||||
action firewall.Action
|
||||
proto types.Protocol
|
||||
sPort *types.Port
|
||||
dPort *types.Port
|
||||
direction types.RuleDirection
|
||||
action types.Action
|
||||
expectSet bool
|
||||
}{
|
||||
{
|
||||
name: "Basic TCP rule with single source",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
proto: types.ProtocolTCP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{80}},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
dPort: &types.Port{Values: []int{80}},
|
||||
direction: types.RuleDirectionIN,
|
||||
action: types.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
@@ -234,77 +235,77 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
||||
proto: types.ProtocolUDP,
|
||||
sPort: &types.Port{Values: []int{1024, 2048}, IsRange: true},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionDrop,
|
||||
direction: types.RuleDirectionOUT,
|
||||
action: types.ActionDrop,
|
||||
expectSet: true,
|
||||
},
|
||||
{
|
||||
name: "All protocols rule",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
destination: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
proto: firewall.ProtocolALL,
|
||||
proto: types.ProtocolALL,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
direction: types.RuleDirectionIN,
|
||||
action: types.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "ICMP rule",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||
proto: firewall.ProtocolICMP,
|
||||
proto: types.ProtocolICMP,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionAccept,
|
||||
direction: types.RuleDirectionIN,
|
||||
action: types.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "TCP rule with multiple source ports",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
||||
proto: types.ProtocolTCP,
|
||||
sPort: &types.Port{Values: []int{80, 443, 8080}},
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
direction: types.RuleDirectionOUT,
|
||||
action: types.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "UDP rule with single IP and port range",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
|
||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
proto: firewall.ProtocolUDP,
|
||||
proto: types.ProtocolUDP,
|
||||
sPort: nil,
|
||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
dPort: &types.Port{Values: []int{5000, 5100}, IsRange: true},
|
||||
direction: types.RuleDirectionIN,
|
||||
action: types.ActionDrop,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "TCP rule with source and destination ports",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||
proto: firewall.ProtocolTCP,
|
||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
||||
dPort: &firewall.Port{Values: []int{22}},
|
||||
direction: firewall.RuleDirectionOUT,
|
||||
action: firewall.ActionAccept,
|
||||
proto: types.ProtocolTCP,
|
||||
sPort: &types.Port{Values: []int{1024, 65535}, IsRange: true},
|
||||
dPort: &types.Port{Values: []int{22}},
|
||||
direction: types.RuleDirectionOUT,
|
||||
action: types.ActionAccept,
|
||||
expectSet: false,
|
||||
},
|
||||
{
|
||||
name: "Drop all incoming traffic",
|
||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
destination: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
proto: firewall.ProtocolALL,
|
||||
proto: types.ProtocolALL,
|
||||
sPort: nil,
|
||||
dPort: nil,
|
||||
direction: firewall.RuleDirectionIN,
|
||||
action: firewall.ActionDrop,
|
||||
direction: types.RuleDirectionIN,
|
||||
action: types.ActionDrop,
|
||||
expectSet: false,
|
||||
},
|
||||
}
|
||||
@@ -441,7 +442,7 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setName := firewall.GenerateSetName(tt.sources)
|
||||
setName := types.GenerateSetName(tt.sources)
|
||||
set, err := r.createIpSet(setName, tt.sources)
|
||||
if err != nil {
|
||||
t.Logf("Failed to create IP set: %v", err)
|
||||
@@ -506,7 +507,7 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
|
||||
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto types.Protocol, sPort, dPort *types.Port, direction types.RuleDirection, action types.Action, expectSet bool) {
|
||||
t.Helper()
|
||||
|
||||
assert.NotNil(t, rule, "Rule should not be nil")
|
||||
@@ -515,21 +516,21 @@ func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, desti
|
||||
if expectSet {
|
||||
assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources")
|
||||
} else if len(sources) == 1 && sources[0].Bits() != 0 {
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
if direction == types.RuleDirectionIN {
|
||||
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0])
|
||||
} else {
|
||||
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0])
|
||||
}
|
||||
}
|
||||
|
||||
if direction == firewall.RuleDirectionIN {
|
||||
if direction == types.RuleDirectionIN {
|
||||
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination)
|
||||
} else {
|
||||
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination)
|
||||
}
|
||||
|
||||
// Verify protocol
|
||||
if proto != firewall.ProtocolALL {
|
||||
if proto != types.ProtocolALL {
|
||||
assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto)
|
||||
}
|
||||
|
||||
@@ -582,7 +583,7 @@ func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) b
|
||||
return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0
|
||||
}
|
||||
|
||||
func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
||||
func containsPort(exprs []expr.Any, port *types.Port, isSource bool) bool {
|
||||
var offset uint32 = 2 // Default offset for destination port
|
||||
if isSource {
|
||||
offset = 0 // Offset for source port
|
||||
@@ -619,7 +620,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
|
||||
func containsProtocol(exprs []expr.Any, proto types.Protocol) bool {
|
||||
var metaFound, cmpFound bool
|
||||
expectedProto, _ := protoToInt(proto)
|
||||
for _, e := range exprs {
|
||||
@@ -637,13 +638,13 @@ func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
|
||||
return metaFound && cmpFound
|
||||
}
|
||||
|
||||
func containsAction(exprs []expr.Any, action firewall.Action) bool {
|
||||
func containsAction(exprs []expr.Any, action types.Action) bool {
|
||||
for _, e := range exprs {
|
||||
if verdict, ok := e.(*expr.Verdict); ok {
|
||||
switch action {
|
||||
case firewall.ActionAccept:
|
||||
case types.ActionAccept:
|
||||
return verdict.Kind == expr.VerdictAccept
|
||||
case firewall.ActionDrop:
|
||||
case types.ActionDrop:
|
||||
return verdict.Kind == expr.VerdictDrop
|
||||
}
|
||||
}
|
||||
@@ -714,3 +715,121 @@ func deleteWorkTable() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeIPRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []netip.Prefix
|
||||
expected []netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "Empty input",
|
||||
input: []netip.Prefix{},
|
||||
expected: []netip.Prefix{},
|
||||
},
|
||||
{
|
||||
name: "Single range",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Two non-overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "One range containing another",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "One range containing another (different order)",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.128/25"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Overlapping ranges (different order)",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.128/25"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Multiple overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.2.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.128/25"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Partially overlapping ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/23"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.2.0/25"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/23"),
|
||||
netip.MustParsePrefix("192.168.2.0/25"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "IPv6 ranges",
|
||||
input: []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
netip.MustParsePrefix("2001:db8:1::/48"),
|
||||
netip.MustParsePrefix("2001:db8:2::/48"),
|
||||
},
|
||||
expected: []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := mergeIPRanges(tt.input)
|
||||
if !reflect.DeepEqual(result, tt.expected) {
|
||||
t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package test
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/types"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
package manager
|
||||
package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// ForwardRule todo figure out better place to this to avoid circular imports
|
||||
type ForwardRule struct {
|
||||
Protocol Protocol
|
||||
DestinationPort Port
|
||||
25
client/firewall/types/ipset.go
Normal file
25
client/firewall/types/ipset.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GenerateSetName generates a unique name for an ipset based on the given sources.
|
||||
func GenerateSetName(sources []netip.Prefix) string {
|
||||
// sort for consistent naming
|
||||
SortPrefixes(sources)
|
||||
|
||||
var sourcesStr strings.Builder
|
||||
for _, src := range sources {
|
||||
sourcesStr.WriteString(src.String())
|
||||
}
|
||||
|
||||
hash := sha256.Sum256([]byte(sourcesStr.String()))
|
||||
shortHash := hex.EncodeToString(hash[:])[:8]
|
||||
|
||||
return fmt.Sprintf("nb-%s", shortHash)
|
||||
}
|
||||
71
client/firewall/types/ipset_test.go
Normal file
71
client/firewall/types/ipset_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateSetName(t *testing.T) {
|
||||
t.Run("Different orders result in same hash", func(t *testing.T) {
|
||||
prefixes1 := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
}
|
||||
prefixes2 := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
result1 := GenerateSetName(prefixes1)
|
||||
result2 := GenerateSetName(prefixes2)
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Result format is correct", func(t *testing.T) {
|
||||
prefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
}
|
||||
|
||||
result := GenerateSetName(prefixes)
|
||||
|
||||
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
|
||||
if err != nil {
|
||||
t.Fatalf("Error matching regex: %v", err)
|
||||
}
|
||||
if !matched {
|
||||
t.Errorf("Result format is incorrect: %s", result)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Empty input produces consistent result", func(t *testing.T) {
|
||||
result1 := GenerateSetName([]netip.Prefix{})
|
||||
result2 := GenerateSetName([]netip.Prefix{})
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
|
||||
prefixes1 := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
}
|
||||
prefixes2 := []netip.Prefix{
|
||||
netip.MustParsePrefix("2001:db8::/32"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
result1 := GenerateSetName(prefixes1)
|
||||
result2 := GenerateSetName(prefixes2)
|
||||
|
||||
if result1 != result2 {
|
||||
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
|
||||
}
|
||||
})
|
||||
}
|
||||
20
client/firewall/types/netip.go
Normal file
20
client/firewall/types/netip.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// SortPrefixes sorts the given slice of netip.Prefix in place.
|
||||
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
||||
func SortPrefixes(prefixes []netip.Prefix) {
|
||||
sort.Slice(prefixes, func(i, j int) bool {
|
||||
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
|
||||
if addrCmp != 0 {
|
||||
return addrCmp < 0
|
||||
}
|
||||
|
||||
// If IP addresses are the same, compare prefix lengths (longer prefixes first)
|
||||
return prefixes[i].Bits() > prefixes[j].Bits()
|
||||
})
|
||||
}
|
||||
@@ -1,11 +1,10 @@
|
||||
package manager
|
||||
package types
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Port of the address for firewall rule
|
||||
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
|
||||
type Port struct {
|
||||
// IsRange is true Values contains two values, the first is the start port, the second is the end port
|
||||
IsRange bool
|
||||
@@ -1,7 +1,6 @@
|
||||
package manager
|
||||
package types
|
||||
|
||||
// Protocol is the protocol of the port
|
||||
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
|
||||
type Protocol string
|
||||
|
||||
const (
|
||||
@@ -1,4 +1,4 @@
|
||||
package manager
|
||||
package types
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
43
client/firewall/types/rule.go
Normal file
43
client/firewall/types/rule.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package types
|
||||
|
||||
import "fmt"
|
||||
|
||||
const (
|
||||
PreroutingFormat = "netbird-prerouting-%s-%t"
|
||||
NatFormat = "netbird-nat-%s-%t"
|
||||
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||
ForwardingFormatPrefix = "netbird-fwd-"
|
||||
)
|
||||
|
||||
// Rule abstraction should be implemented by each firewall manager
|
||||
//
|
||||
// Each firewall type for different OS can use different type
|
||||
// of the properties to hold data of the created rule
|
||||
type Rule interface {
|
||||
// GetRuleID returns the rule id
|
||||
GetRuleID() string
|
||||
}
|
||||
|
||||
// RuleDirection is the traffic direction which a rule is applied
|
||||
type RuleDirection int
|
||||
|
||||
const (
|
||||
// RuleDirectionIN applies to filters that handlers incoming traffic
|
||||
RuleDirectionIN RuleDirection = iota
|
||||
// RuleDirectionOUT applies to filters that handlers outgoing traffic
|
||||
RuleDirectionOUT
|
||||
)
|
||||
|
||||
// Action is the action to be taken on a rule
|
||||
type Action int
|
||||
|
||||
const (
|
||||
// ActionAccept is the action to accept a packet
|
||||
ActionAccept Action = iota
|
||||
// ActionDrop is the action to drop a packet
|
||||
ActionDrop
|
||||
)
|
||||
|
||||
func GenRuleKey(format string, pair RouterPair) string {
|
||||
return fmt.Sprintf(format, pair.ID, pair.Inverse)
|
||||
}
|
||||
@@ -13,7 +13,8 @@ import (
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/interface"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
@@ -46,7 +47,7 @@ type Manager struct {
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
wgIface IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
nativeFirewall firewall.Firewall
|
||||
|
||||
mutex sync.RWMutex
|
||||
|
||||
@@ -74,7 +75,7 @@ func Create(iface IFaceMapper) (*Manager, error) {
|
||||
return create(iface)
|
||||
}
|
||||
|
||||
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
|
||||
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Firewall) (*Manager, error) {
|
||||
mgr, err := create(iface)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -134,7 +135,7 @@ func (m *Manager) IsServerRouteSupported() bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
func (m *Manager) AddNatRule(pair types.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
@@ -142,7 +143,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
// RemoveNatRule removes a routing firewall rule
|
||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
func (m *Manager) RemoveNatRule(pair types.RouterPair) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
@@ -155,19 +156,19 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
proto types.Protocol,
|
||||
sPort *types.Port,
|
||||
dPort *types.Port,
|
||||
action types.Action,
|
||||
_ string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
) ([]types.Rule, error) {
|
||||
r := Rule{
|
||||
id: uuid.New().String(),
|
||||
ip: ip,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
matchByIP: true,
|
||||
drop: action == firewall.ActionDrop,
|
||||
drop: action == types.ActionDrop,
|
||||
comment: comment,
|
||||
}
|
||||
if ipNormalized := ip.To4(); ipNormalized != nil {
|
||||
@@ -188,16 +189,16 @@ func (m *Manager) AddPeerFiltering(
|
||||
}
|
||||
|
||||
switch proto {
|
||||
case firewall.ProtocolTCP:
|
||||
case types.ProtocolTCP:
|
||||
r.protoLayer = layers.LayerTypeTCP
|
||||
case firewall.ProtocolUDP:
|
||||
case types.ProtocolUDP:
|
||||
r.protoLayer = layers.LayerTypeUDP
|
||||
case firewall.ProtocolICMP:
|
||||
case types.ProtocolICMP:
|
||||
r.protoLayer = layers.LayerTypeICMPv4
|
||||
if r.ipLayer == layers.LayerTypeIPv6 {
|
||||
r.protoLayer = layers.LayerTypeICMPv6
|
||||
}
|
||||
case firewall.ProtocolALL:
|
||||
case types.ProtocolALL:
|
||||
r.protoLayer = layerTypeAll
|
||||
}
|
||||
|
||||
@@ -207,17 +208,17 @@ func (m *Manager) AddPeerFiltering(
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
m.mutex.Unlock()
|
||||
return []firewall.Rule{&r}, nil
|
||||
return []types.Rule{&r}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto types.Protocol, sPort *types.Port, dPort *types.Port, action types.Action) (types.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errRouteNotSupported
|
||||
}
|
||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
func (m *Manager) DeleteRouteRule(rule types.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
}
|
||||
@@ -225,7 +226,7 @@ func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
func (m *Manager) DeletePeerRule(rule types.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -254,12 +255,12 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
func (m *Manager) AddDNATRule(rule types.ForwardRule) (types.Rule, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
func (m *Manager) DeleteDNATRule(rule types.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
@@ -90,8 +90,8 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
stateful: false,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Single rule allowing all traffic
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
|
||||
fw.ActionAccept, "", "allow all")
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolALL, nil, nil,
|
||||
types.ActionAccept, "", "allow all")
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||
@@ -111,10 +111,10 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Add explicit rules matching return traffic pattern
|
||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||
ip := generateRandomIPs(1)[0]
|
||||
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{1024 + i}},
|
||||
&fw.Port{Values: []int{80}},
|
||||
fw.ActionAccept, "", "explicit return")
|
||||
_, err := m.AddPeerFiltering(ip, types.ProtocolTCP,
|
||||
&types.Port{Values: []int{1024 + i}},
|
||||
&types.Port{Values: []int{80}},
|
||||
types.ActionAccept, "", "explicit return")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
},
|
||||
@@ -125,8 +125,8 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
stateful: true,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Add some basic rules but rely on state for established connections
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
|
||||
fw.ActionDrop, "", "default drop")
|
||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP, nil, nil,
|
||||
types.ActionDrop, "", "default drop")
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Connection tracking with established connections",
|
||||
@@ -587,10 +587,10 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP,
|
||||
&types.Port{Values: []int{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
types.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -678,10 +678,10 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP,
|
||||
&types.Port{Values: []int{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
types.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -796,10 +796,10 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP,
|
||||
&types.Port{Values: []int{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
types.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -883,10 +883,10 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
})
|
||||
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
||||
&fw.Port{Values: []int{80}},
|
||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP,
|
||||
&types.Port{Values: []int{80}},
|
||||
nil,
|
||||
fw.ActionAccept, "", "return traffic")
|
||||
types.ActionAccept, "", "return traffic")
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
@@ -43,12 +43,12 @@ func TestManagerCreate(t *testing.T) {
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
t.Errorf("failed to create Firewall: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
t.Error("Manager is nil")
|
||||
t.Error("Firewall is nil")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,14 +63,14 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
t.Errorf("failed to create Firewall: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
action := fw.ActionDrop
|
||||
proto := types.ProtocolTCP
|
||||
port := &types.Port{Values: []int{80}}
|
||||
action := types.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
@@ -97,14 +97,14 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
t.Errorf("failed to create Firewall: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
action := fw.ActionDrop
|
||||
proto := types.ProtocolTCP
|
||||
port := &types.Port{Values: []int{80}}
|
||||
action := types.ActionDrop
|
||||
comment := "Test rule 2"
|
||||
|
||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
@@ -138,7 +138,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in bool
|
||||
expDir fw.RuleDirection
|
||||
expDir types.RuleDirection
|
||||
ip net.IP
|
||||
dPort uint16
|
||||
hook func([]byte) bool
|
||||
@@ -147,7 +147,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
{
|
||||
name: "Test Outgoing UDP Packet Hook",
|
||||
in: false,
|
||||
expDir: fw.RuleDirectionOUT,
|
||||
expDir: types.RuleDirectionOUT,
|
||||
ip: net.IPv4(10, 168, 0, 1),
|
||||
dPort: 8000,
|
||||
hook: func([]byte) bool { return true },
|
||||
@@ -155,7 +155,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
{
|
||||
name: "Test Incoming UDP Packet Hook",
|
||||
in: true,
|
||||
expDir: fw.RuleDirectionIN,
|
||||
expDir: types.RuleDirectionIN,
|
||||
ip: net.IPv6loopback,
|
||||
dPort: 9000,
|
||||
hook: func([]byte) bool { return false },
|
||||
@@ -217,14 +217,14 @@ func TestManagerReset(t *testing.T) {
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
t.Errorf("failed to create Firewall: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []int{80}}
|
||||
action := fw.ActionDrop
|
||||
proto := types.ProtocolTCP
|
||||
port := &types.Port{Values: []int{80}}
|
||||
action := types.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
||||
@@ -235,7 +235,7 @@ func TestManagerReset(t *testing.T) {
|
||||
|
||||
err = m.Reset(nil)
|
||||
if err != nil {
|
||||
t.Errorf("failed to reset Manager: %v", err)
|
||||
t.Errorf("failed to reset Firewall: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -251,7 +251,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
|
||||
m, err := Create(ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
t.Errorf("failed to create Firewall: %v", err)
|
||||
return
|
||||
}
|
||||
m.wgNetwork = &net.IPNet{
|
||||
@@ -260,8 +260,8 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
}
|
||||
|
||||
ip := net.ParseIP("0.0.0.0")
|
||||
proto := fw.ProtocolUDP
|
||||
action := fw.ActionAccept
|
||||
proto := types.ProtocolUDP
|
||||
action := types.ActionAccept
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
|
||||
@@ -304,7 +304,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
}
|
||||
|
||||
if err = m.Reset(nil); err != nil {
|
||||
t.Errorf("failed to reset Manager: %v", err)
|
||||
t.Errorf("failed to reset Firewall: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -319,7 +319,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
// creating manager instance
|
||||
manager, err := Create(iface)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Manager: %s", err)
|
||||
t.Fatalf("Failed to create Firewall: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
require.NoError(t, manager.Reset(nil))
|
||||
@@ -463,8 +463,8 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.100")
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
port := &types.Port{Values: []int{1000 + i}}
|
||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, types.ActionAccept, "", "accept HTTP traffic")
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"net/netip"
|
||||
"strconv"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
)
|
||||
|
||||
type RuleID string
|
||||
@@ -19,12 +19,12 @@ func (r RuleID) GetRuleID() string {
|
||||
func GenerateRouteRuleKey(
|
||||
sources []netip.Prefix,
|
||||
destination netip.Prefix,
|
||||
proto manager.Protocol,
|
||||
sPort *manager.Port,
|
||||
dPort *manager.Port,
|
||||
action manager.Action,
|
||||
proto types.Protocol,
|
||||
sPort *types.Port,
|
||||
dPort *types.Port,
|
||||
action types.Action,
|
||||
) RuleID {
|
||||
manager.SortPrefixes(sources)
|
||||
types.SortPrefixes(sources)
|
||||
|
||||
h := sha256.New()
|
||||
|
||||
|
||||
@@ -15,7 +15,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/interface"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
@@ -30,17 +31,17 @@ type Manager interface {
|
||||
|
||||
// DefaultManager uses firewall manager to handle
|
||||
type DefaultManager struct {
|
||||
firewall firewall.Manager
|
||||
firewall _interface.Firewall
|
||||
ipsetCounter int
|
||||
peerRulesPairs map[id.RuleID][]firewall.Rule
|
||||
peerRulesPairs map[id.RuleID][]types.Rule
|
||||
routeRules map[id.RuleID]struct{}
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||
func NewDefaultManager(fm _interface.Firewall) *DefaultManager {
|
||||
return &DefaultManager{
|
||||
firewall: fm,
|
||||
peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
|
||||
peerRulesPairs: make(map[id.RuleID][]types.Rule),
|
||||
routeRules: make(map[id.RuleID]struct{}),
|
||||
}
|
||||
}
|
||||
@@ -132,7 +133,7 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
)
|
||||
}
|
||||
|
||||
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||
newRulePairs := make(map[id.RuleID][]types.Rule)
|
||||
ipsetByRuleSelectors := make(map[string]string)
|
||||
|
||||
for _, r := range rules {
|
||||
@@ -251,7 +252,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
r *mgmProto.FirewallRule,
|
||||
ipsetName string,
|
||||
) (id.RuleID, []firewall.Rule, error) {
|
||||
) (id.RuleID, []types.Rule, error) {
|
||||
ip := net.ParseIP(r.PeerIP)
|
||||
if ip == nil {
|
||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
@@ -267,13 +268,13 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||
}
|
||||
|
||||
var port *firewall.Port
|
||||
var port *types.Port
|
||||
if r.Port != "" {
|
||||
value, err := strconv.Atoi(r.Port)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
|
||||
}
|
||||
port = &firewall.Port{
|
||||
port = &types.Port{
|
||||
Values: []int{value},
|
||||
}
|
||||
}
|
||||
@@ -283,7 +284,7 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
return ruleID, rulesPair, nil
|
||||
}
|
||||
|
||||
var rules []firewall.Rule
|
||||
var rules []types.Rule
|
||||
switch r.Direction {
|
||||
case mgmProto.RuleDirection_IN:
|
||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||
@@ -304,12 +305,12 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
|
||||
func (d *DefaultManager) addInRules(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
protocol types.Protocol,
|
||||
port *types.Port,
|
||||
action types.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
) ([]types.Rule, error) {
|
||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
@@ -320,12 +321,12 @@ func (d *DefaultManager) addInRules(
|
||||
|
||||
func (d *DefaultManager) addOutRules(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
protocol types.Protocol,
|
||||
port *types.Port,
|
||||
action types.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
) ([]types.Rule, error) {
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -341,10 +342,10 @@ func (d *DefaultManager) addOutRules(
|
||||
// getPeerRuleID() returns unique ID for the rule based on its parameters.
|
||||
func (d *DefaultManager) getPeerRuleID(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
proto types.Protocol,
|
||||
direction int,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
port *types.Port,
|
||||
action types.Action,
|
||||
comment string,
|
||||
) id.RuleID {
|
||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
||||
@@ -491,7 +492,7 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
|
||||
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
||||
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]types.Rule) {
|
||||
log.Debugf("rollback ACL to previous state")
|
||||
for _, rules := range newRulePairs {
|
||||
for _, rule := range rules {
|
||||
@@ -502,49 +503,49 @@ func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
||||
}
|
||||
}
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (types.Protocol, error) {
|
||||
switch protocol {
|
||||
case mgmProto.RuleProtocol_TCP:
|
||||
return firewall.ProtocolTCP, nil
|
||||
return types.ProtocolTCP, nil
|
||||
case mgmProto.RuleProtocol_UDP:
|
||||
return firewall.ProtocolUDP, nil
|
||||
return types.ProtocolUDP, nil
|
||||
case mgmProto.RuleProtocol_ICMP:
|
||||
return firewall.ProtocolICMP, nil
|
||||
return types.ProtocolICMP, nil
|
||||
case mgmProto.RuleProtocol_ALL:
|
||||
return firewall.ProtocolALL, nil
|
||||
return types.ProtocolALL, nil
|
||||
default:
|
||||
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
|
||||
return types.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
|
||||
}
|
||||
}
|
||||
|
||||
func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) bool {
|
||||
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
|
||||
func shouldSkipInvertedRule(protocol types.Protocol, port *types.Port) bool {
|
||||
return protocol == types.ProtocolALL || protocol == types.ProtocolICMP || port == nil
|
||||
}
|
||||
|
||||
func convertFirewallAction(action mgmProto.RuleAction) (firewall.Action, error) {
|
||||
func convertFirewallAction(action mgmProto.RuleAction) (types.Action, error) {
|
||||
switch action {
|
||||
case mgmProto.RuleAction_ACCEPT:
|
||||
return firewall.ActionAccept, nil
|
||||
return types.ActionAccept, nil
|
||||
case mgmProto.RuleAction_DROP:
|
||||
return firewall.ActionDrop, nil
|
||||
return types.ActionDrop, nil
|
||||
default:
|
||||
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
|
||||
return types.ActionDrop, fmt.Errorf("invalid action type: %d", action)
|
||||
}
|
||||
}
|
||||
|
||||
func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
|
||||
func convertPortInfo(portInfo *mgmProto.PortInfo) *types.Port {
|
||||
if portInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if portInfo.GetPort() != 0 {
|
||||
return &firewall.Port{
|
||||
return &types.Port{
|
||||
Values: []int{int(portInfo.GetPort())},
|
||||
}
|
||||
}
|
||||
|
||||
if portInfo.GetRange() != nil {
|
||||
return &firewall.Port{
|
||||
return &types.Port{
|
||||
IsRange: true,
|
||||
Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/interface"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
@@ -56,7 +56,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(fw manager.Manager) {
|
||||
defer func(fw _interface.Firewall) {
|
||||
_ = fw.Reset(nil)
|
||||
}(fw)
|
||||
acl := NewDefaultManager(fw)
|
||||
@@ -349,7 +349,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
t.Errorf("create firewall: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(fw manager.Manager) {
|
||||
defer func(fw _interface.Firewall) {
|
||||
_ = fw.Reset(nil)
|
||||
}(fw)
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
@@ -9,7 +9,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/interface"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -19,13 +20,13 @@ const (
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
firewall firewall.Manager
|
||||
firewall _interface.Firewall
|
||||
|
||||
fwRules []firewall.Rule
|
||||
fwRules []types.Rule
|
||||
dnsForwarder *DNSForwarder
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager) *Manager {
|
||||
func NewManager(fw _interface.Firewall) *Manager {
|
||||
return &Manager{
|
||||
firewall: fw,
|
||||
}
|
||||
@@ -79,7 +80,7 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (h *Manager) allowDNSFirewall() error {
|
||||
dport := &firewall.Port{
|
||||
dport := &types.Port{
|
||||
IsRange: false,
|
||||
Values: []int{ListenPort},
|
||||
}
|
||||
@@ -88,7 +89,7 @@ func (h *Manager) allowDNSFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "")
|
||||
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, types.ProtocolUDP, nil, dport, types.ActionAccept, "", "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||
return err
|
||||
|
||||
@@ -25,7 +25,8 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/interface"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
@@ -169,7 +170,7 @@ type Engine struct {
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
firewall firewallManager.Manager
|
||||
firewall _interface.Firewall
|
||||
routeManager routemanager.Manager
|
||||
acl acl.Manager
|
||||
dnsForwardMgr *dnsfwd.Manager
|
||||
@@ -504,15 +505,15 @@ func (e *Engine) initFirewall() error {
|
||||
}
|
||||
|
||||
rosenpassPort := e.rpManager.GetAddress().Port
|
||||
port := firewallManager.Port{Values: []int{rosenpassPort}}
|
||||
port := types.Port{Values: []int{rosenpassPort}}
|
||||
|
||||
// this rule is static and will be torn down on engine down by the firewall manager
|
||||
if _, err := e.firewall.AddPeerFiltering(
|
||||
net.IP{0, 0, 0, 0},
|
||||
firewallManager.ProtocolUDP,
|
||||
types.ProtocolUDP,
|
||||
nil,
|
||||
&port,
|
||||
firewallManager.ActionAccept,
|
||||
types.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
); err != nil {
|
||||
@@ -540,10 +541,10 @@ func (e *Engine) blockLanAccess() {
|
||||
if _, err := e.firewall.AddRouteFiltering(
|
||||
[]netip.Prefix{v4},
|
||||
network,
|
||||
firewallManager.ProtocolALL,
|
||||
types.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
firewallManager.ActionDrop,
|
||||
types.ActionDrop,
|
||||
); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add fw rule for network %s: %w", network, err))
|
||||
}
|
||||
@@ -1774,7 +1775,7 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
forwardingRules := make([]firewallManager.ForwardRule, 0, len(rules))
|
||||
forwardingRules := make([]types.ForwardRule, 0, len(rules))
|
||||
for _, rule := range rules {
|
||||
proto, err := convertToFirewallProtocol(rule.GetProtocol())
|
||||
if err != nil {
|
||||
@@ -1800,7 +1801,7 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
|
||||
continue
|
||||
}
|
||||
|
||||
forwardRule := firewallManager.ForwardRule{
|
||||
forwardRule := types.ForwardRule{
|
||||
Protocol: proto,
|
||||
DestinationPort: *dstPortInfo,
|
||||
TranslatedAddress: translateIP,
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||
)
|
||||
|
||||
@@ -14,192 +14,192 @@ func (e *Engine) mocForwardRules() {
|
||||
e.ingressGatewayMgr = ingressgw.NewManager(e.firewall)
|
||||
}
|
||||
err := e.ingressGatewayMgr.Update(
|
||||
[]firewallManager.ForwardRule{
|
||||
[]types.ForwardRule{
|
||||
{
|
||||
Protocol: "tcp",
|
||||
DestinationPort: firewallManager.Port{IsRange: false, Values: []int{10000}},
|
||||
DestinationPort: types.Port{IsRange: false, Values: []int{10000}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: false, Values: []int{20000}},
|
||||
TranslatedPort: types.Port{IsRange: false, Values: []int{20000}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10100, 10199}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{10100, 10199}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20100, 20199}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{20100, 20199}},
|
||||
},
|
||||
{
|
||||
Protocol: "tcp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10200, 10299}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{10200, 10299}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20200, 20299}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{20200, 20299}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10300, 10399}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{10300, 10399}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20300, 20399}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{20300, 20399}},
|
||||
},
|
||||
{
|
||||
Protocol: "tcp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10100, 10199}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{10100, 10199}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20100, 20199}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{20100, 20199}},
|
||||
},
|
||||
{
|
||||
Protocol: "tcp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10400, 10499}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{10400, 10499}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20400, 20499}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{20400, 20499}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10500, 10599}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{10500, 10599}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20500, 20599}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{20500, 20599}},
|
||||
},
|
||||
{
|
||||
Protocol: "tcp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10600, 10699}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{10600, 10699}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20600, 20699}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{20600, 20699}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10700, 10799}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{10700, 10799}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20700, 20799}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{20700, 20799}},
|
||||
},
|
||||
{
|
||||
Protocol: "tcp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10800, 10899}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{10800, 10899}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20800, 20899}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{20800, 20899}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10900, 10999}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{10900, 10999}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20900, 20999}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{20900, 20999}},
|
||||
},
|
||||
{
|
||||
Protocol: "tcp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11000, 11099}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{11000, 11099}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21000, 21099}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{21000, 21099}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11100, 11199}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{11100, 11199}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21100, 21199}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{21100, 21199}},
|
||||
},
|
||||
{
|
||||
Protocol: "tcp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11200, 11299}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{11200, 11299}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21200, 21299}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{21200, 21299}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11300, 11399}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{11300, 11399}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21300, 21399}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{21300, 21399}},
|
||||
},
|
||||
{
|
||||
Protocol: "tcp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11400, 11499}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{11400, 11499}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21400, 21499}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{21400, 21499}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11500, 11599}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{11500, 11599}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21500, 21599}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{21500, 21599}},
|
||||
},
|
||||
{
|
||||
Protocol: "tcp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11600, 11699}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{11600, 11699}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21600, 21699}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{21600, 21699}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11700, 11799}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{11700, 11799}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21700, 21799}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{21700, 21799}},
|
||||
},
|
||||
{
|
||||
Protocol: "tcp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11800, 11899}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{11800, 11899}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21800, 21899}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{21800, 21899}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11900, 11999}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{11900, 11999}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21900, 21999}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{21900, 21999}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12000, 12099}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{12000, 12099}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22000, 22099}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{22000, 22099}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12100, 12199}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{12100, 12199}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22100, 22199}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{22100, 22199}},
|
||||
},
|
||||
{
|
||||
Protocol: "tcp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12200, 12299}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{12200, 12299}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22200, 22299}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{22200, 22299}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12300, 12399}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{12300, 12399}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22300, 22399}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{22300, 22399}},
|
||||
},
|
||||
{
|
||||
Protocol: "tcp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12400, 12499}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{12400, 12499}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22400, 22499}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{22400, 22499}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12500, 12599}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{12500, 12599}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22500, 22599}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{22500, 22599}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12600, 12699}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{12600, 12699}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22600, 22699}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{22600, 22699}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12700, 12799}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{12700, 12799}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22700, 22799}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{22700, 22799}},
|
||||
},
|
||||
{
|
||||
Protocol: "tcp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12800, 12899}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{12800, 12899}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.31.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22800, 22899}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{22800, 22899}},
|
||||
},
|
||||
{
|
||||
Protocol: "udp",
|
||||
DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12900, 12999}},
|
||||
DestinationPort: types.Port{IsRange: true, Values: []int{12900, 12999}},
|
||||
TranslatedAddress: netip.MustParseAddr("100.64.10.206"),
|
||||
TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22900, 22999}},
|
||||
TranslatedPort: types.Port{IsRange: true, Values: []int{22900, 22999}},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -8,29 +8,30 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/interface"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
)
|
||||
|
||||
type RulePair struct {
|
||||
firewallManager.ForwardRule
|
||||
firewallManager.Rule
|
||||
types.ForwardRule
|
||||
types.Rule
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
firewallManager firewallManager.Manager
|
||||
firewall _interface.Firewall
|
||||
|
||||
rules map[string]RulePair // keys is the ID of the ForwardRule
|
||||
rulesMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewManager(firewall firewallManager.Manager) *Manager {
|
||||
func NewManager(firewall _interface.Firewall) *Manager {
|
||||
return &Manager{
|
||||
firewallManager: firewall,
|
||||
rules: make(map[string]RulePair),
|
||||
firewall: firewall,
|
||||
rules: make(map[string]RulePair),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Manager) Update(forwardRules []firewallManager.ForwardRule) error {
|
||||
func (h *Manager) Update(forwardRules []types.ForwardRule) error {
|
||||
h.rulesMu.Lock()
|
||||
defer h.rulesMu.Unlock()
|
||||
|
||||
@@ -48,7 +49,7 @@ func (h *Manager) Update(forwardRules []firewallManager.ForwardRule) error {
|
||||
continue
|
||||
}
|
||||
|
||||
rule, err := h.firewallManager.AddDNATRule(fwdRule)
|
||||
rule, err := h.firewall.AddDNATRule(fwdRule)
|
||||
if err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("failed to add forward rule '%s': %v", fwdRule.String(), err))
|
||||
continue
|
||||
@@ -62,7 +63,7 @@ func (h *Manager) Update(forwardRules []firewallManager.ForwardRule) error {
|
||||
|
||||
// Remove deleted rules
|
||||
for id, rulePair := range toDelete {
|
||||
if err := h.firewallManager.DeleteDNATRule(rulePair.Rule); err != nil {
|
||||
if err := h.firewall.DeleteDNATRule(rulePair.Rule); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete forward rule '%s': %v", rulePair.ForwardRule.String(), err))
|
||||
}
|
||||
delete(h.rules, id)
|
||||
@@ -78,18 +79,18 @@ func (h *Manager) Close() error {
|
||||
log.Infof("clean up all forward rules (%d)", len(h.rules))
|
||||
var mErr *multierror.Error
|
||||
for _, rule := range h.rules {
|
||||
if err := h.firewallManager.DeleteDNATRule(rule.Rule); err != nil {
|
||||
if err := h.firewall.DeleteDNATRule(rule.Rule); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete forward rule '%s': %v", rule, err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(mErr)
|
||||
}
|
||||
|
||||
func (h *Manager) Rules() []firewallManager.ForwardRule {
|
||||
func (h *Manager) Rules() []types.ForwardRule {
|
||||
h.rulesMu.Lock()
|
||||
defer h.rulesMu.Unlock()
|
||||
|
||||
rules := make([]firewallManager.ForwardRule, 0, len(h.rules))
|
||||
rules := make([]types.ForwardRule, 0, len(h.rules))
|
||||
for _, rulePair := range h.rules {
|
||||
rules = append(rules, rulePair.ForwardRule)
|
||||
}
|
||||
|
||||
@@ -6,39 +6,39 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewallManager.Protocol, error) {
|
||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (types.Protocol, error) {
|
||||
switch protocol {
|
||||
case mgmProto.RuleProtocol_TCP:
|
||||
return firewallManager.ProtocolTCP, nil
|
||||
return types.ProtocolTCP, nil
|
||||
case mgmProto.RuleProtocol_UDP:
|
||||
return firewallManager.ProtocolUDP, nil
|
||||
return types.ProtocolUDP, nil
|
||||
case mgmProto.RuleProtocol_ICMP:
|
||||
return firewallManager.ProtocolICMP, nil
|
||||
return types.ProtocolICMP, nil
|
||||
case mgmProto.RuleProtocol_ALL:
|
||||
return firewallManager.ProtocolALL, nil
|
||||
return types.ProtocolALL, nil
|
||||
default:
|
||||
return firewallManager.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
|
||||
return types.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
|
||||
}
|
||||
}
|
||||
|
||||
// convertPortInfo todo: write validation for portInfo
|
||||
func convertPortInfo(portInfo *mgmProto.PortInfo) *firewallManager.Port {
|
||||
func convertPortInfo(portInfo *mgmProto.PortInfo) *types.Port {
|
||||
if portInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if portInfo.GetPort() != 0 {
|
||||
return &firewallManager.Port{
|
||||
return &types.Port{
|
||||
Values: []int{int(portInfo.GetPort())},
|
||||
}
|
||||
}
|
||||
|
||||
if portInfo.GetRange() != nil {
|
||||
return &firewallManager.Port{
|
||||
return &types.Port{
|
||||
IsRange: true,
|
||||
Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/interface"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
@@ -44,7 +44,7 @@ type Manager interface {
|
||||
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
EnableServerRouter(firewall firewall.Manager) error
|
||||
EnableServerRouter(firewall _interface.Firewall) error
|
||||
Stop(stateManager *statemanager.Manager)
|
||||
}
|
||||
|
||||
@@ -214,7 +214,7 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
||||
return routeselector.NewRouteSelector()
|
||||
}
|
||||
|
||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
func (m *DefaultManager) EnableServerRouter(firewall _interface.Firewall) error {
|
||||
if m.disableServerRoutes {
|
||||
log.Info("server routes are disabled")
|
||||
return nil
|
||||
|
||||
@@ -3,7 +3,7 @@ package routemanager
|
||||
import (
|
||||
"context"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/interface"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
@@ -78,7 +78,7 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList
|
||||
|
||||
}
|
||||
|
||||
func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
func (m *MockManager) EnableServerRouter(firewall _interface.Firewall) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/interface"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -22,6 +22,6 @@ func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (*serverRouter, error) {
|
||||
func newServerRouter(context.Context, iface.IWGIface, _interface.Firewall, *peer.Status) (*serverRouter, error) {
|
||||
return nil, fmt.Errorf("server route not supported on this os")
|
||||
}
|
||||
|
||||
@@ -10,7 +10,8 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/interface"
|
||||
"github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
@@ -21,12 +22,12 @@ type serverRouter struct {
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
routes map[route.ID]*route.Route
|
||||
firewall firewall.Manager
|
||||
firewall _interface.Firewall
|
||||
wgInterface iface.IWGIface
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) {
|
||||
func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall _interface.Firewall, statusRecorder *peer.Status) (*serverRouter, error) {
|
||||
return &serverRouter{
|
||||
ctx: ctx,
|
||||
routes: make(map[route.ID]*route.Route),
|
||||
@@ -167,7 +168,7 @@ func (m *serverRouter) cleanUp() {
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
}
|
||||
|
||||
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
|
||||
func routeToRouterPair(route *route.Route) (types.RouterPair, error) {
|
||||
// TODO: add ipv6
|
||||
source := getDefaultPrefix(route.Network)
|
||||
|
||||
@@ -177,7 +178,7 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
|
||||
destination = getDefaultPrefix(destination)
|
||||
}
|
||||
|
||||
return firewall.RouterPair{
|
||||
return types.RouterPair{
|
||||
ID: route.ID,
|
||||
Source: source,
|
||||
Destination: destination,
|
||||
|
||||
@@ -3,7 +3,7 @@ package server
|
||||
import (
|
||||
"context"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/types"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user