diff --git a/client/firewall/create.go b/client/firewall/factory.go similarity index 80% rename from client/firewall/create.go rename to client/firewall/factory.go index 9466f4b4d..67bfe8902 100644 --- a/client/firewall/create.go +++ b/client/firewall/factory.go @@ -8,13 +8,13 @@ import ( log "github.com/sirupsen/logrus" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/interface" "github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/internal/statemanager" ) // NewFirewall creates a firewall manager instance -func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (_interface.Firewall, error) { if !iface.IsUserspaceBind() { return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) } diff --git a/client/firewall/create_linux.go b/client/firewall/factory_linux.go similarity index 93% rename from client/firewall/create_linux.go rename to client/firewall/factory_linux.go index 076d08ec2..f0b2aa73f 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/factory_linux.go @@ -11,8 +11,8 @@ import ( "github.com/google/nftables" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/firewall/interface" nbiptables "github.com/netbirdio/netbird/client/firewall/iptables" - firewall "github.com/netbirdio/netbird/client/firewall/manager" nbnftables "github.com/netbirdio/netbird/client/firewall/nftables" "github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -33,7 +33,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" // FWType is the type for the firewall type type FWType int -func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (_interface.Firewall, error) { // on the linux system we try to user nftables or iptables // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers @@ -50,7 +50,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal return createUserspaceFirewall(iface, fm) } -func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { +func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (_interface.Firewall, error) { fm, err := createFW(iface) if err != nil { return nil, fmt.Errorf("create firewall: %s", err) @@ -63,7 +63,7 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) return fm, nil } -func createFW(iface IFaceMapper) (firewall.Manager, error) { +func createFW(iface IFaceMapper) (_interface.Firewall, error) { switch check() { case IPTABLES: log.Info("creating an iptables firewall manager") @@ -77,7 +77,7 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) { } } -func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) { +func createUserspaceFirewall(iface IFaceMapper, fm _interface.Firewall) (_interface.Firewall, error) { var errUsp error if fm != nil { fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) diff --git a/client/firewall/interface/firewall.go b/client/firewall/interface/firewall.go new file mode 100644 index 000000000..e84f15345 --- /dev/null +++ b/client/firewall/interface/firewall.go @@ -0,0 +1,67 @@ +package _interface + +import ( + "net" + "net/netip" + + "github.com/netbirdio/netbird/client/firewall/types" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +// Firewall is the high level abstraction of a firewall manager +// +// It declares methods which handle actions required by the +// Netbird client for ACL and routing functionality +type Firewall interface { + Init(stateManager *statemanager.Manager) error + + // AllowNetbird allows netbird interface traffic + AllowNetbird() error + + // AddPeerFiltering adds a rule to the firewall + // + // If comment argument is empty firewall manager should set + // rule ID as comment for the rule + AddPeerFiltering( + ip net.IP, + proto types.Protocol, + sPort *types.Port, + dPort *types.Port, + action types.Action, + ipsetName string, + comment string, + ) ([]types.Rule, error) + + // DeletePeerRule from the firewall by rule definition + DeletePeerRule(rule types.Rule) error + + // IsServerRouteSupported returns true if the firewall supports server side routing operations + IsServerRouteSupported() bool + + AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto types.Protocol, sPort *types.Port, dPort *types.Port, action types.Action) (types.Rule, error) + + // DeleteRouteRule deletes a routing rule + DeleteRouteRule(rule types.Rule) error + + // AddNatRule inserts a routing NAT rule + AddNatRule(pair types.RouterPair) error + + // RemoveNatRule removes a routing NAT rule + RemoveNatRule(pair types.RouterPair) error + + // SetLegacyManagement sets the legacy management mode + SetLegacyManagement(legacy bool) error + + // Reset firewall to the default state + Reset(stateManager *statemanager.Manager) error + + // Flush the changes to firewall controller + Flush() error + + // AddDNATRule adds a DNAT rule + AddDNATRule(types.ForwardRule) (types.Rule, error) + + // DeleteDNATRule deletes a DNAT rule + // todo: do you need a string ID or the complete rule? + DeleteDNATRule(types.Rule) error +} diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 2592ff840..323a0ee54 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -10,7 +10,7 @@ import ( "github.com/nadoo/ipset" log "github.com/sirupsen/logrus" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/types" "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -80,12 +80,12 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error { func (m *aclManager) AddPeerFiltering( ip net.IP, - protocol firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, - action firewall.Action, + protocol types.Protocol, + sPort *types.Port, + dPort *types.Port, + action types.Action, ipsetName string, -) ([]firewall.Rule, error) { +) ([]types.Rule, error) { var dPortVal, sPortVal string if dPort != nil && dPort.Values != nil { // TODO: we support only one port per rule in current implementation of ACLs @@ -107,7 +107,7 @@ func (m *aclManager) AddPeerFiltering( // if ruleset already exists it means we already have the firewall rule // so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager. ipList.addIP(ip.String()) - return []firewall.Rule{&Rule{ + return []types.Rule{&Rule{ ruleID: uuid.New().String(), ipsetName: ipsetName, ip: ip.String(), @@ -152,11 +152,11 @@ func (m *aclManager) AddPeerFiltering( m.updateState() - return []firewall.Rule{rule}, nil + return []types.Rule{rule}, nil } // DeletePeerRule from the firewall by rule definition -func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { +func (m *aclManager) DeletePeerRule(rule types.Rule) error { r, ok := rule.(*Rule) if !ok { return fmt.Errorf("invalid rule type") @@ -354,7 +354,7 @@ func (m *aclManager) updateState() { } // filterRuleSpecs returns the specs of a filtering rule -func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.Action, ipsetName string) (specs []string) { +func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action types.Action, ipsetName string) (specs []string) { matchByIP := true // don't use IP matching if IP is ip 0.0.0.0 if ip.String() == "0.0.0.0" { @@ -380,8 +380,8 @@ func filterRuleSpecs(ip net.IP, protocol, sPort, dPort string, action firewall.A return append(specs, "-j", actionToStr(action)) } -func actionToStr(action firewall.Action) string { - if action == firewall.ActionAccept { +func actionToStr(action types.Action) string { + if action == types.ActionAccept { return "ACCEPT" } return "DROP" diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 4a977aea0..517647e14 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -12,7 +12,8 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/legacy" + "github.com/netbirdio/netbird/client/firewall/types" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -97,13 +98,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { // Comment will be ignored because some system this feature is not supported func (m *Manager) AddPeerFiltering( ip net.IP, - protocol firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, - action firewall.Action, + protocol types.Protocol, + sPort *types.Port, + dPort *types.Port, + action types.Action, ipsetName string, _ string, -) ([]firewall.Rule, error) { +) ([]types.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -113,11 +114,11 @@ func (m *Manager) AddPeerFiltering( func (m *Manager) AddRouteFiltering( sources []netip.Prefix, destination netip.Prefix, - proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, - action firewall.Action, -) (firewall.Rule, error) { + proto types.Protocol, + sPort *types.Port, + dPort *types.Port, + action types.Action, +) (types.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -129,14 +130,14 @@ func (m *Manager) AddRouteFiltering( } // DeletePeerRule from the firewall by rule definition -func (m *Manager) DeletePeerRule(rule firewall.Rule) error { +func (m *Manager) DeletePeerRule(rule types.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() return m.aclMgr.DeletePeerRule(rule) } -func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { +func (m *Manager) DeleteRouteRule(rule types.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -147,14 +148,14 @@ func (m *Manager) IsServerRouteSupported() bool { return true } -func (m *Manager) AddNatRule(pair firewall.RouterPair) error { +func (m *Manager) AddNatRule(pair types.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() return m.router.AddNatRule(pair) } -func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { +func (m *Manager) RemoveNatRule(pair types.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -162,7 +163,7 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { } func (m *Manager) SetLegacyManagement(isLegacy bool) error { - return firewall.SetLegacyManagement(m.router, isLegacy) + return legacy.SetLegacyRouter(m.router, isLegacy) } // Reset firewall to the default state @@ -200,7 +201,7 @@ func (m *Manager) AllowNetbird() error { "all", nil, nil, - firewall.ActionAccept, + types.ActionAccept, "", "", ) @@ -213,12 +214,12 @@ func (m *Manager) AllowNetbird() error { // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } -func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { +func (m *Manager) AddDNATRule(rule types.ForwardRule) (types.Rule, error) { return nil, fmt.Errorf("not implemented") } // DeleteDNATRule deletes a DNAT rule -func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { +func (m *Manager) DeleteDNATRule(rule types.Rule) error { return nil } diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index fe0bc86de..8e989584b 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -9,7 +9,7 @@ import ( "github.com/coreos/go-iptables/iptables" "github.com/stretchr/testify/require" - fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/types" "github.com/netbirdio/netbird/client/iface" ) @@ -68,13 +68,13 @@ func TestIptablesManager(t *testing.T) { time.Sleep(time.Second) }() - var rule2 []fw.Rule + var rule2 []types.Rule t.Run("add second rule", func(t *testing.T) { ip := net.ParseIP("10.20.0.3") - port := &fw.Port{ + port := &types.Port{ Values: []int{8043: 8046}, } - rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range") + rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, types.ActionAccept, "", "accept HTTPS traffic from ports range") require.NoError(t, err, "failed to add rule") for _, r := range rule2 { @@ -95,8 +95,8 @@ func TestIptablesManager(t *testing.T) { t.Run("reset check", func(t *testing.T) { // add second rule ip := net.ParseIP("10.20.0.3") - port := &fw.Port{Values: []int{5353}} - _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic") + port := &types.Port{Values: []int{5353}} + _, err = manager.AddPeerFiltering(ip, "udp", nil, port, types.ActionAccept, "", "accept Fake DNS traffic") require.NoError(t, err, "failed to add rule") err = manager.Reset(nil) @@ -141,13 +141,13 @@ func TestIptablesManagerIPSet(t *testing.T) { time.Sleep(time.Second) }() - var rule2 []fw.Rule + var rule2 []types.Rule t.Run("add second rule", func(t *testing.T) { ip := net.ParseIP("10.20.0.3") - port := &fw.Port{ + port := &types.Port{ Values: []int{443}, } - rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range") + rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, types.ActionAccept, "default", "accept HTTPS traffic from ports range") for _, r := range rule2 { require.NoError(t, err, "failed to add rule") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") @@ -214,8 +214,8 @@ func TestIptablesCreatePerformance(t *testing.T) { ip := net.ParseIP("10.20.0.100") start := time.Now() for i := 0; i < testMax; i++ { - port := &fw.Port{Values: []int{1000 + i}} - _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic") + port := &types.Port{Values: []int{1000 + i}} + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, types.ActionAccept, "", "accept HTTP traffic") require.NoError(t, err, "failed to add rule") } diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index d067a3e7b..66d72016c 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -14,7 +14,7 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/types" "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -42,11 +42,11 @@ const ( type routeFilteringRuleParams struct { Sources []netip.Prefix Destination netip.Prefix - Proto firewall.Protocol - SPort *firewall.Port - DPort *firewall.Port - Direction firewall.RuleDirection - Action firewall.Action + Proto types.Protocol + SPort *types.Port + DPort *types.Port + Direction types.RuleDirection + Action types.Action SetName string } @@ -106,11 +106,11 @@ func (r *router) init(stateManager *statemanager.Manager) error { func (r *router) AddRouteFiltering( sources []netip.Prefix, destination netip.Prefix, - proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, - action firewall.Action, -) (firewall.Rule, error) { + proto types.Protocol, + sPort *types.Port, + dPort *types.Port, + action types.Action, +) (types.Rule, error) { ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) if _, ok := r.rules[string(ruleKey)]; ok { return ruleKey, nil @@ -118,7 +118,7 @@ func (r *router) AddRouteFiltering( var setName string if len(sources) > 1 { - setName = firewall.GenerateSetName(sources) + setName = types.GenerateSetName(sources) if _, err := r.ipsetCounter.Increment(setName, sources); err != nil { return nil, fmt.Errorf("create or get ipset: %w", err) } @@ -146,7 +146,7 @@ func (r *router) AddRouteFiltering( return ruleKey, nil } -func (r *router) DeleteRouteRule(rule firewall.Rule) error { +func (r *router) DeleteRouteRule(rule types.Rule) error { ruleKey := rule.GetRuleID() if rule, exists := r.rules[ruleKey]; exists { @@ -202,7 +202,7 @@ func (r *router) deleteIpSet(setName string) error { } // AddNatRule inserts an iptables rule pair into the nat chain -func (r *router) AddNatRule(pair firewall.RouterPair) error { +func (r *router) AddNatRule(pair types.RouterPair) error { if r.legacyManagement { log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) if err := r.addLegacyRouteRule(pair); err != nil { @@ -218,7 +218,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error { return fmt.Errorf("add nat rule: %w", err) } - if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil { + if err := r.addNatRule(types.GetInversePair(pair)); err != nil { return fmt.Errorf("add inverse nat rule: %w", err) } @@ -228,12 +228,12 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error { } // RemoveNatRule removes an iptables rule pair from forwarding and nat chains -func (r *router) RemoveNatRule(pair firewall.RouterPair) error { +func (r *router) RemoveNatRule(pair types.RouterPair) error { if err := r.removeNatRule(pair); err != nil { return fmt.Errorf("remove nat rule: %w", err) } - if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { + if err := r.removeNatRule(types.GetInversePair(pair)); err != nil { return fmt.Errorf("remove inverse nat rule: %w", err) } @@ -247,8 +247,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { } // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls -func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { - ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) +func (r *router) addLegacyRouteRule(pair types.RouterPair) error { + ruleKey := types.GenRuleKey(types.ForwardingFormat, pair) if err := r.removeLegacyRouteRule(pair); err != nil { return err @@ -264,8 +264,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { return nil } -func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { - ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) +func (r *router) removeLegacyRouteRule(pair types.RouterPair) error { + ruleKey := types.GenRuleKey(types.ForwardingFormat, pair) if rule, exists := r.rules[ruleKey]; exists { if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { @@ -293,7 +293,7 @@ func (r *router) SetLegacyManagement(isLegacy bool) { func (r *router) RemoveAllLegacyRouteRules() error { var merr *multierror.Error for k, rule := range r.rules { - if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) { + if !strings.HasPrefix(k, types.ForwardingFormatPrefix) { continue } if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { @@ -478,8 +478,8 @@ func (r *router) cleanJumpRules() error { return nil } -func (r *router) addNatRule(pair firewall.RouterPair) error { - ruleKey := firewall.GenKey(firewall.NatFormat, pair) +func (r *router) addNatRule(pair types.RouterPair) error { + ruleKey := types.GenRuleKey(types.NatFormat, pair) if rule, exists := r.rules[ruleKey]; exists { if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil { @@ -514,8 +514,8 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { return nil } -func (r *router) removeNatRule(pair firewall.RouterPair) error { - ruleKey := firewall.GenKey(firewall.NatFormat, pair) +func (r *router) removeNatRule(pair types.RouterPair) error { + ruleKey := types.GenRuleKey(types.NatFormat, pair) if rule, exists := r.rules[ruleKey]; exists { if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil { @@ -567,7 +567,7 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { rule = append(rule, "-d", params.Destination.String()) - if params.Proto != firewall.ProtocolALL { + if params.Proto != types.ProtocolALL { rule = append(rule, "-p", strings.ToLower(string(params.Proto))) rule = append(rule, applyPort("--sport", params.SPort)...) rule = append(rule, applyPort("--dport", params.DPort)...) @@ -578,7 +578,7 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { return rule } -func applyPort(flag string, port *firewall.Port) []string { +func applyPort(flag string, port *types.Port) []string { if port == nil { return nil } diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 861bf8601..8368de660 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" + "github.com/netbirdio/netbird/client/firewall/types" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -54,7 +54,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING) require.True(t, exists, "prerouting jump rule should exist") - pair := firewall.RouterPair{ + pair := types.RouterPair{ ID: "abc", Source: netip.MustParsePrefix("100.100.100.1/32"), Destination: netip.MustParsePrefix("100.100.100.0/24"), @@ -89,7 +89,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) { err = manager.AddNatRule(testCase.InputPair) require.NoError(t, err, "marking rule should be inserted") - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) + natRuleKey := types.GenRuleKey(types.NatFormat, testCase.InputPair) markingRule := []string{ "-i", ifaceMock.Name(), "-m", "conntrack", @@ -114,8 +114,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) { } // Check inverse rule - inversePair := firewall.GetInversePair(testCase.InputPair) - inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair) + inversePair := types.GetInversePair(testCase.InputPair) + inverseRuleKey := types.GenRuleKey(types.NatFormat, inversePair) inverseMarkingRule := []string{ "!", "-i", ifaceMock.Name(), "-m", "conntrack", @@ -164,7 +164,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) { err = manager.RemoveNatRule(testCase.InputPair) require.NoError(t, err, "shouldn't return error") - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) + natRuleKey := types.GenRuleKey(types.NatFormat, testCase.InputPair) markingRule := []string{ "-i", ifaceMock.Name(), "-m", "conntrack", @@ -183,8 +183,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) { require.False(t, found, "marking rule should not exist in the manager map") // Check inverse rule removal - inversePair := firewall.GetInversePair(testCase.InputPair) - inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair) + inversePair := types.GetInversePair(testCase.InputPair) + inverseRuleKey := types.GenRuleKey(types.NatFormat, inversePair) inverseMarkingRule := []string{ "!", "-i", ifaceMock.Name(), "-m", "conntrack", @@ -226,22 +226,22 @@ func TestRouter_AddRouteFiltering(t *testing.T) { name string sources []netip.Prefix destination netip.Prefix - proto firewall.Protocol - sPort *firewall.Port - dPort *firewall.Port - direction firewall.RuleDirection - action firewall.Action + proto types.Protocol + sPort *types.Port + dPort *types.Port + direction types.RuleDirection + action types.Action expectSet bool }{ { name: "Basic TCP rule with single source", sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, destination: netip.MustParsePrefix("10.0.0.0/24"), - proto: firewall.ProtocolTCP, + proto: types.ProtocolTCP, sPort: nil, - dPort: &firewall.Port{Values: []int{80}}, - direction: firewall.RuleDirectionIN, - action: firewall.ActionAccept, + dPort: &types.Port{Values: []int{80}}, + direction: types.RuleDirectionIN, + action: types.ActionAccept, expectSet: false, }, { @@ -251,77 +251,77 @@ func TestRouter_AddRouteFiltering(t *testing.T) { netip.MustParsePrefix("192.168.0.0/16"), }, destination: netip.MustParsePrefix("10.0.0.0/8"), - proto: firewall.ProtocolUDP, - sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true}, + proto: types.ProtocolUDP, + sPort: &types.Port{Values: []int{1024, 2048}, IsRange: true}, dPort: nil, - direction: firewall.RuleDirectionOUT, - action: firewall.ActionDrop, + direction: types.RuleDirectionOUT, + action: types.ActionDrop, expectSet: true, }, { name: "All protocols rule", sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, destination: netip.MustParsePrefix("0.0.0.0/0"), - proto: firewall.ProtocolALL, + proto: types.ProtocolALL, sPort: nil, dPort: nil, - direction: firewall.RuleDirectionIN, - action: firewall.ActionAccept, + direction: types.RuleDirectionIN, + action: types.ActionAccept, expectSet: false, }, { name: "ICMP rule", sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, destination: netip.MustParsePrefix("10.0.0.0/8"), - proto: firewall.ProtocolICMP, + proto: types.ProtocolICMP, sPort: nil, dPort: nil, - direction: firewall.RuleDirectionIN, - action: firewall.ActionAccept, + direction: types.RuleDirectionIN, + action: types.ActionAccept, expectSet: false, }, { name: "TCP rule with multiple source ports", sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")}, destination: netip.MustParsePrefix("192.168.0.0/16"), - proto: firewall.ProtocolTCP, - sPort: &firewall.Port{Values: []int{80, 443, 8080}}, + proto: types.ProtocolTCP, + sPort: &types.Port{Values: []int{80, 443, 8080}}, dPort: nil, - direction: firewall.RuleDirectionOUT, - action: firewall.ActionAccept, + direction: types.RuleDirectionOUT, + action: types.ActionAccept, expectSet: false, }, { name: "UDP rule with single IP and port range", sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")}, destination: netip.MustParsePrefix("10.0.0.0/24"), - proto: firewall.ProtocolUDP, + proto: types.ProtocolUDP, sPort: nil, - dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true}, - direction: firewall.RuleDirectionIN, - action: firewall.ActionDrop, + dPort: &types.Port{Values: []int{5000, 5100}, IsRange: true}, + direction: types.RuleDirectionIN, + action: types.ActionDrop, expectSet: false, }, { name: "TCP rule with source and destination ports", sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, destination: netip.MustParsePrefix("172.16.0.0/16"), - proto: firewall.ProtocolTCP, - sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true}, - dPort: &firewall.Port{Values: []int{22}}, - direction: firewall.RuleDirectionOUT, - action: firewall.ActionAccept, + proto: types.ProtocolTCP, + sPort: &types.Port{Values: []int{1024, 65535}, IsRange: true}, + dPort: &types.Port{Values: []int{22}}, + direction: types.RuleDirectionOUT, + action: types.ActionAccept, expectSet: false, }, { name: "Drop all incoming traffic", sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, destination: netip.MustParsePrefix("192.168.0.0/24"), - proto: firewall.ProtocolALL, + proto: types.ProtocolALL, sPort: nil, dPort: nil, - direction: firewall.RuleDirectionIN, - action: firewall.ActionDrop, + direction: types.RuleDirectionIN, + action: types.ActionDrop, expectSet: false, }, } @@ -357,7 +357,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { expectedRule := genRouteFilteringRuleSpec(params) if tt.expectSet { - setName := firewall.GenerateSetName(tt.sources) + setName := types.GenerateSetName(tt.sources) params.SetName = setName expectedRule = genRouteFilteringRuleSpec(params) diff --git a/client/firewall/legacy/router.go b/client/firewall/legacy/router.go new file mode 100644 index 000000000..4cfc17ee1 --- /dev/null +++ b/client/firewall/legacy/router.go @@ -0,0 +1,35 @@ +package legacy + +import ( + "fmt" + + "github.com/sirupsen/logrus" +) + +// Router defines the interface for legacy management operations +type Router interface { + RemoveAllLegacyRouteRules() error + GetLegacyManagement() bool + SetLegacyManagement(bool) +} + +// SetLegacyRouter sets the route manager to use legacy management +func SetLegacyRouter(router Router, isLegacy bool) error { + oldLegacy := router.GetLegacyManagement() + + if oldLegacy != isLegacy { + router.SetLegacyManagement(isLegacy) + logrus.Debugf("Set legacy management to %v", isLegacy) + } + + // client reconnected to a newer mgmt, we need to clean up the legacy rules + if !isLegacy && oldLegacy { + if err := router.RemoveAllLegacyRouteRules(); err != nil { + return fmt.Errorf("remove legacy routing rules: %v", err) + } + + logrus.Debugf("Legacy routing rules removed") + } + + return nil +} diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go deleted file mode 100644 index dc4b737b6..000000000 --- a/client/firewall/manager/firewall.go +++ /dev/null @@ -1,196 +0,0 @@ -package manager - -import ( - "crypto/sha256" - "encoding/hex" - "fmt" - "net" - "net/netip" - "sort" - "strings" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/statemanager" -) - -const ( - ForwardingFormatPrefix = "netbird-fwd-" - ForwardingFormat = "netbird-fwd-%s-%t" - PreroutingFormat = "netbird-prerouting-%s-%t" - NatFormat = "netbird-nat-%s-%t" -) - -// Rule abstraction should be implemented by each firewall manager -// -// Each firewall type for different OS can use different type -// of the properties to hold data of the created rule -type Rule interface { - // GetRuleID returns the rule id - GetRuleID() string -} - -// RuleDirection is the traffic direction which a rule is applied -type RuleDirection int - -const ( - // RuleDirectionIN applies to filters that handlers incoming traffic - RuleDirectionIN RuleDirection = iota - // RuleDirectionOUT applies to filters that handlers outgoing traffic - RuleDirectionOUT -) - -// Action is the action to be taken on a rule -type Action int - -const ( - // ActionAccept is the action to accept a packet - ActionAccept Action = iota - // ActionDrop is the action to drop a packet - ActionDrop -) - -// Manager is the high level abstraction of a firewall manager -// -// It declares methods which handle actions required by the -// Netbird client for ACL and routing functionality -type Manager interface { - Init(stateManager *statemanager.Manager) error - - // AllowNetbird allows netbird interface traffic - AllowNetbird() error - - // AddPeerFiltering adds a rule to the firewall - // - // If comment argument is empty firewall manager should set - // rule ID as comment for the rule - AddPeerFiltering( - ip net.IP, - proto Protocol, - sPort *Port, - dPort *Port, - action Action, - ipsetName string, - comment string, - ) ([]Rule, error) - - // DeletePeerRule from the firewall by rule definition - DeletePeerRule(rule Rule) error - - // IsServerRouteSupported returns true if the firewall supports server side routing operations - IsServerRouteSupported() bool - - AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error) - - // DeleteRouteRule deletes a routing rule - DeleteRouteRule(rule Rule) error - - // AddNatRule inserts a routing NAT rule - AddNatRule(pair RouterPair) error - - // RemoveNatRule removes a routing NAT rule - RemoveNatRule(pair RouterPair) error - - // SetLegacyManagement sets the legacy management mode - SetLegacyManagement(legacy bool) error - - // Reset firewall to the default state - Reset(stateManager *statemanager.Manager) error - - // Flush the changes to firewall controller - Flush() error - - // AddDNATRule adds a DNAT rule - AddDNATRule(ForwardRule) (Rule, error) - - // DeleteDNATRule deletes a DNAT rule - // todo: do you need a string ID or the complete rule? - DeleteDNATRule(Rule) error -} - -func GenKey(format string, pair RouterPair) string { - return fmt.Sprintf(format, pair.ID, pair.Inverse) -} - -// LegacyManager defines the interface for legacy management operations -type LegacyManager interface { - RemoveAllLegacyRouteRules() error - GetLegacyManagement() bool - SetLegacyManagement(bool) -} - -// SetLegacyManagement sets the route manager to use legacy management -func SetLegacyManagement(router LegacyManager, isLegacy bool) error { - oldLegacy := router.GetLegacyManagement() - - if oldLegacy != isLegacy { - router.SetLegacyManagement(isLegacy) - log.Debugf("Set legacy management to %v", isLegacy) - } - - // client reconnected to a newer mgmt, we need to clean up the legacy rules - if !isLegacy && oldLegacy { - if err := router.RemoveAllLegacyRouteRules(); err != nil { - return fmt.Errorf("remove legacy routing rules: %v", err) - } - - log.Debugf("Legacy routing rules removed") - } - - return nil -} - -// GenerateSetName generates a unique name for an ipset based on the given sources. -func GenerateSetName(sources []netip.Prefix) string { - // sort for consistent naming - SortPrefixes(sources) - - var sourcesStr strings.Builder - for _, src := range sources { - sourcesStr.WriteString(src.String()) - } - - hash := sha256.Sum256([]byte(sourcesStr.String())) - shortHash := hex.EncodeToString(hash[:])[:8] - - return fmt.Sprintf("nb-%s", shortHash) -} - -// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix -func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { - if len(prefixes) == 0 { - return prefixes - } - - merged := []netip.Prefix{prefixes[0]} - for _, prefix := range prefixes[1:] { - last := merged[len(merged)-1] - if last.Contains(prefix.Addr()) { - // If the current prefix is contained within the last merged prefix, skip it - continue - } - if prefix.Contains(last.Addr()) { - // If the current prefix contains the last merged prefix, replace it - merged[len(merged)-1] = prefix - } else { - // Otherwise, add the current prefix to the merged list - merged = append(merged, prefix) - } - } - - return merged -} - -// SortPrefixes sorts the given slice of netip.Prefix in place. -// It sorts first by IP address, then by prefix length (most specific to least specific). -func SortPrefixes(prefixes []netip.Prefix) { - sort.Slice(prefixes, func(i, j int) bool { - addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr()) - if addrCmp != 0 { - return addrCmp < 0 - } - - // If IP addresses are the same, compare prefix lengths (longer prefixes first) - return prefixes[i].Bits() > prefixes[j].Bits() - }) -} diff --git a/client/firewall/manager/firewall_test.go b/client/firewall/manager/firewall_test.go deleted file mode 100644 index 3f47d6679..000000000 --- a/client/firewall/manager/firewall_test.go +++ /dev/null @@ -1,192 +0,0 @@ -package manager_test - -import ( - "net/netip" - "reflect" - "regexp" - "testing" - - "github.com/netbirdio/netbird/client/firewall/manager" -) - -func TestGenerateSetName(t *testing.T) { - t.Run("Different orders result in same hash", func(t *testing.T) { - prefixes1 := []netip.Prefix{ - netip.MustParsePrefix("192.168.1.0/24"), - netip.MustParsePrefix("10.0.0.0/8"), - } - prefixes2 := []netip.Prefix{ - netip.MustParsePrefix("10.0.0.0/8"), - netip.MustParsePrefix("192.168.1.0/24"), - } - - result1 := manager.GenerateSetName(prefixes1) - result2 := manager.GenerateSetName(prefixes2) - - if result1 != result2 { - t.Errorf("Different orders produced different hashes: %s != %s", result1, result2) - } - }) - - t.Run("Result format is correct", func(t *testing.T) { - prefixes := []netip.Prefix{ - netip.MustParsePrefix("192.168.1.0/24"), - netip.MustParsePrefix("10.0.0.0/8"), - } - - result := manager.GenerateSetName(prefixes) - - matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result) - if err != nil { - t.Fatalf("Error matching regex: %v", err) - } - if !matched { - t.Errorf("Result format is incorrect: %s", result) - } - }) - - t.Run("Empty input produces consistent result", func(t *testing.T) { - result1 := manager.GenerateSetName([]netip.Prefix{}) - result2 := manager.GenerateSetName([]netip.Prefix{}) - - if result1 != result2 { - t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2) - } - }) - - t.Run("IPv4 and IPv6 mixing", func(t *testing.T) { - prefixes1 := []netip.Prefix{ - netip.MustParsePrefix("192.168.1.0/24"), - netip.MustParsePrefix("2001:db8::/32"), - } - prefixes2 := []netip.Prefix{ - netip.MustParsePrefix("2001:db8::/32"), - netip.MustParsePrefix("192.168.1.0/24"), - } - - result1 := manager.GenerateSetName(prefixes1) - result2 := manager.GenerateSetName(prefixes2) - - if result1 != result2 { - t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2) - } - }) -} - -func TestMergeIPRanges(t *testing.T) { - tests := []struct { - name string - input []netip.Prefix - expected []netip.Prefix - }{ - { - name: "Empty input", - input: []netip.Prefix{}, - expected: []netip.Prefix{}, - }, - { - name: "Single range", - input: []netip.Prefix{ - netip.MustParsePrefix("192.168.1.0/24"), - }, - expected: []netip.Prefix{ - netip.MustParsePrefix("192.168.1.0/24"), - }, - }, - { - name: "Two non-overlapping ranges", - input: []netip.Prefix{ - netip.MustParsePrefix("192.168.1.0/24"), - netip.MustParsePrefix("10.0.0.0/8"), - }, - expected: []netip.Prefix{ - netip.MustParsePrefix("192.168.1.0/24"), - netip.MustParsePrefix("10.0.0.0/8"), - }, - }, - { - name: "One range containing another", - input: []netip.Prefix{ - netip.MustParsePrefix("192.168.0.0/16"), - netip.MustParsePrefix("192.168.1.0/24"), - }, - expected: []netip.Prefix{ - netip.MustParsePrefix("192.168.0.0/16"), - }, - }, - { - name: "One range containing another (different order)", - input: []netip.Prefix{ - netip.MustParsePrefix("192.168.1.0/24"), - netip.MustParsePrefix("192.168.0.0/16"), - }, - expected: []netip.Prefix{ - netip.MustParsePrefix("192.168.0.0/16"), - }, - }, - { - name: "Overlapping ranges", - input: []netip.Prefix{ - netip.MustParsePrefix("192.168.1.0/24"), - netip.MustParsePrefix("192.168.1.128/25"), - }, - expected: []netip.Prefix{ - netip.MustParsePrefix("192.168.1.0/24"), - }, - }, - { - name: "Overlapping ranges (different order)", - input: []netip.Prefix{ - netip.MustParsePrefix("192.168.1.128/25"), - netip.MustParsePrefix("192.168.1.0/24"), - }, - expected: []netip.Prefix{ - netip.MustParsePrefix("192.168.1.0/24"), - }, - }, - { - name: "Multiple overlapping ranges", - input: []netip.Prefix{ - netip.MustParsePrefix("192.168.0.0/16"), - netip.MustParsePrefix("192.168.1.0/24"), - netip.MustParsePrefix("192.168.2.0/24"), - netip.MustParsePrefix("192.168.1.128/25"), - }, - expected: []netip.Prefix{ - netip.MustParsePrefix("192.168.0.0/16"), - }, - }, - { - name: "Partially overlapping ranges", - input: []netip.Prefix{ - netip.MustParsePrefix("192.168.0.0/23"), - netip.MustParsePrefix("192.168.1.0/24"), - netip.MustParsePrefix("192.168.2.0/25"), - }, - expected: []netip.Prefix{ - netip.MustParsePrefix("192.168.0.0/23"), - netip.MustParsePrefix("192.168.2.0/25"), - }, - }, - { - name: "IPv6 ranges", - input: []netip.Prefix{ - netip.MustParsePrefix("2001:db8::/32"), - netip.MustParsePrefix("2001:db8:1::/48"), - netip.MustParsePrefix("2001:db8:2::/48"), - }, - expected: []netip.Prefix{ - netip.MustParsePrefix("2001:db8::/32"), - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := manager.MergeIPRanges(tt.input) - if !reflect.DeepEqual(result, tt.expected) { - t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected) - } - }) - } -} diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 8c1d89e68..9968696b2 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -15,7 +15,7 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/types" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -84,13 +84,13 @@ func (m *AclManager) init(workTable *nftables.Table) error { // rule ID as comment for the rule func (m *AclManager) AddPeerFiltering( ip net.IP, - proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, - action firewall.Action, + proto types.Protocol, + sPort *types.Port, + dPort *types.Port, + action types.Action, ipsetName string, comment string, -) ([]firewall.Rule, error) { +) ([]types.Rule, error) { var ipset *nftables.Set if ipsetName != "" { var err error @@ -100,7 +100,7 @@ func (m *AclManager) AddPeerFiltering( } } - newRules := make([]firewall.Rule, 0, 2) + newRules := make([]types.Rule, 0, 2) ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment) if err != nil { return nil, err @@ -111,7 +111,7 @@ func (m *AclManager) AddPeerFiltering( } // DeletePeerRule from the firewall by rule definition -func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { +func (m *AclManager) DeletePeerRule(rule types.Rule) error { r, ok := rule.(*Rule) if !ok { return fmt.Errorf("invalid rule type") @@ -234,10 +234,10 @@ func (m *AclManager) Flush() error { func (m *AclManager) addIOFiltering( ip net.IP, - proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, - action firewall.Action, + proto types.Protocol, + sPort *types.Port, + dPort *types.Port, + action types.Action, ipset *nftables.Set, comment string, ) (*Rule, error) { @@ -253,7 +253,7 @@ func (m *AclManager) addIOFiltering( var expressions []expr.Any - if proto != firewall.ProtocolALL { + if proto != types.ProtocolALL { expressions = append(expressions, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, @@ -341,9 +341,9 @@ func (m *AclManager) addIOFiltering( } switch action { - case firewall.ActionAccept: + case types.ActionAccept: expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept}) - case firewall.ActionDrop: + case types.ActionDrop: expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop}) } @@ -672,7 +672,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error { return nil } -func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string { +func generatePeerRuleId(ip net.IP, sPort *types.Port, dPort *types.Port, action types.Action, ipset *nftables.Set) string { rulesetID := ":" if sPort != nil { rulesetID += sPort.String() @@ -689,7 +689,7 @@ func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, a return "set:" + ipset.Name + rulesetID } -func encodePort(port firewall.Port) []byte { +func encodePort(port types.Port) []byte { bs := make([]byte, 2) binary.BigEndian.PutUint16(bs, uint16(port.Values[0])) return bs @@ -701,13 +701,13 @@ func ifname(n string) []byte { return b } -func protoToInt(protocol firewall.Protocol) (uint8, error) { +func protoToInt(protocol types.Protocol) (uint8, error) { switch protocol { - case firewall.ProtocolTCP: + case types.ProtocolTCP: return unix.IPPROTO_TCP, nil - case firewall.ProtocolUDP: + case types.ProtocolUDP: return unix.IPPROTO_UDP, nil - case firewall.ProtocolICMP: + case types.ProtocolICMP: return unix.IPPROTO_ICMP, nil } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 5b9b9c63a..7a05083a0 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -13,7 +13,8 @@ import ( "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/legacy" + "github.com/netbirdio/netbird/client/firewall/types" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -114,13 +115,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { // rule ID as comment for the rule func (m *Manager) AddPeerFiltering( ip net.IP, - proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, - action firewall.Action, + proto types.Protocol, + sPort *types.Port, + dPort *types.Port, + action types.Action, ipsetName string, comment string, -) ([]firewall.Rule, error) { +) ([]types.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -135,11 +136,11 @@ func (m *Manager) AddPeerFiltering( func (m *Manager) AddRouteFiltering( sources []netip.Prefix, destination netip.Prefix, - proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, - action firewall.Action, -) (firewall.Rule, error) { + proto types.Protocol, + sPort *types.Port, + dPort *types.Port, + action types.Action, +) (types.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -151,7 +152,7 @@ func (m *Manager) AddRouteFiltering( } // DeletePeerRule from the firewall by rule definition -func (m *Manager) DeletePeerRule(rule firewall.Rule) error { +func (m *Manager) DeletePeerRule(rule types.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -159,7 +160,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { } // DeleteRouteRule deletes a routing rule -func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { +func (m *Manager) DeleteRouteRule(rule types.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -170,14 +171,14 @@ func (m *Manager) IsServerRouteSupported() bool { return true } -func (m *Manager) AddNatRule(pair firewall.RouterPair) error { +func (m *Manager) AddNatRule(pair types.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() return m.router.AddNatRule(pair) } -func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { +func (m *Manager) RemoveNatRule(pair types.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -238,7 +239,7 @@ func (m *Manager) AllowNetbird() error { // SetLegacyManagement sets the route manager to use legacy management func (m *Manager) SetLegacyManagement(isLegacy bool) error { - return firewall.SetLegacyManagement(m.router, isLegacy) + return legacy.SetLegacyRouter(m.router, isLegacy) } // Reset firewall to the default state @@ -330,7 +331,7 @@ func (m *Manager) Flush() error { } // AddDNATRule adds a DNAT rule -func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { +func (m *Manager) AddDNATRule(rule types.ForwardRule) (types.Rule, error) { r := &Rule{ ruleID: rule.GetRuleID(), } @@ -338,7 +339,7 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) } // DeleteDNATRule deletes a DNAT rule -func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { +func (m *Manager) DeleteDNATRule(rule types.Rule) error { return nil } diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 9c9637282..7a19fd7e2 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -15,7 +15,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/sys/unix" - fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/types" "github.com/netbirdio/netbird/client/iface" ) @@ -74,7 +74,7 @@ func TestNftablesManager(t *testing.T) { testClient := &nftables.Conn{} - rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{53}}, fw.ActionDrop, "", "") + rule, err := manager.AddPeerFiltering(ip, types.ProtocolTCP, nil, &types.Port{Values: []int{53}}, types.ActionDrop, "", "") require.NoError(t, err, "failed to add rule") err = manager.Flush() @@ -200,8 +200,8 @@ func TestNFtablesCreatePerformance(t *testing.T) { ip := net.ParseIP("10.20.0.100") start := time.Now() for i := 0; i < testMax; i++ { - port := &fw.Port{Values: []int{1000 + i}} - _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic") + port := &types.Port{Values: []int{1000 + i}} + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, types.ActionAccept, "", "accept HTTP traffic") require.NoError(t, err, "failed to add rule") if i%100 == 0 { @@ -283,20 +283,20 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { }) ip := net.ParseIP("100.96.0.1") - _, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []int{80}}, fw.ActionAccept, "", "test rule") + _, err = manager.AddPeerFiltering(ip, types.ProtocolTCP, nil, &types.Port{Values: []int{80}}, types.ActionAccept, "", "test rule") require.NoError(t, err, "failed to add peer filtering rule") _, err = manager.AddRouteFiltering( []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")}, netip.MustParsePrefix("10.1.0.0/24"), - fw.ProtocolTCP, + types.ProtocolTCP, nil, - &fw.Port{Values: []int{443}}, - fw.ActionAccept, + &types.Port{Values: []int{443}}, + types.ActionAccept, ) require.NoError(t, err, "failed to add route filtering rule") - pair := fw.RouterPair{ + pair := types.RouterPair{ Source: netip.MustParsePrefix("192.168.1.0/24"), Destination: netip.MustParsePrefix("10.0.0.0/24"), Masquerade: true, diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 34bc9a9bc..f351cf5d9 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -18,7 +18,7 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/types" "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" nbnet "github.com/netbirdio/netbird/util/net" @@ -167,11 +167,11 @@ func (r *router) createContainers() error { func (r *router) AddRouteFiltering( sources []netip.Prefix, destination netip.Prefix, - proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, - action firewall.Action, -) (firewall.Rule, error) { + proto types.Protocol, + sPort *types.Port, + dPort *types.Port, + action types.Action, +) (types.Rule, error) { ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) if _, ok := r.rules[string(ruleKey)]; ok { @@ -200,7 +200,7 @@ func (r *router) AddRouteFiltering( exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...) // Handle protocol - if proto != firewall.ProtocolALL { + if proto != types.ProtocolALL { protoNum, err := protoToInt(proto) if err != nil { return nil, fmt.Errorf("convert protocol to number: %w", err) @@ -219,7 +219,7 @@ func (r *router) AddRouteFiltering( exprs = append(exprs, &expr.Counter{}) var verdict expr.VerdictKind - if action == firewall.ActionAccept { + if action == types.ActionAccept { verdict = expr.VerdictAccept } else { verdict = expr.VerdictDrop @@ -248,7 +248,7 @@ func (r *router) AddRouteFiltering( } func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { - setName := firewall.GenerateSetName(sources) + setName := types.GenerateSetName(sources) ref, err := r.ipsetCounter.Increment(setName, sources) if err != nil { return nil, fmt.Errorf("create or get ipset for sources: %w", err) @@ -270,7 +270,7 @@ func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr return exprs, nil } -func (r *router) DeleteRouteRule(rule firewall.Rule) error { +func (r *router) DeleteRouteRule(rule types.Rule) error { if err := r.refreshRulesMap(); err != nil { return fmt.Errorf(refreshRulesMapError, err) } @@ -307,7 +307,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) { // overlapping prefixes will result in an error, so we need to merge them - sources = firewall.MergeIPRanges(sources) + sources = mergeIPRanges(sources) set := &nftables.Set{ Name: setName, @@ -403,7 +403,7 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error { } // AddNatRule appends a nftables rule pair to the nat chain -func (r *router) AddNatRule(pair firewall.RouterPair) error { +func (r *router) AddNatRule(pair types.RouterPair) error { if err := r.refreshRulesMap(); err != nil { return fmt.Errorf(refreshRulesMapError, err) } @@ -420,7 +420,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error { return fmt.Errorf("add nat rule: %w", err) } - if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil { + if err := r.addNatRule(types.GetInversePair(pair)); err != nil { return fmt.Errorf("add inverse nat rule: %w", err) } } @@ -433,7 +433,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error { } // addNatRule inserts a nftables rule to the conn client flush queue -func (r *router) addNatRule(pair firewall.RouterPair) error { +func (r *router) addNatRule(pair types.RouterPair) error { sourceExp := generateCIDRMatcherExpressions(true, pair.Source) destExp := generateCIDRMatcherExpressions(false, pair.Destination) @@ -494,7 +494,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { }, ) - ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair) + ruleKey := types.GenRuleKey(types.PreroutingFormat, pair) if _, exists := r.rules[ruleKey]; exists { if err := r.removeNatRule(pair); err != nil { @@ -584,7 +584,7 @@ func (r *router) addPostroutingRules() error { } // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls -func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { +func (r *router) addLegacyRouteRule(pair types.RouterPair) error { sourceExp := generateCIDRMatcherExpressions(true, pair.Source) destExp := generateCIDRMatcherExpressions(false, pair.Destination) @@ -597,7 +597,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic - ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) + ruleKey := types.GenRuleKey(types.ForwardingFormat, pair) if _, exists := r.rules[ruleKey]; exists { if err := r.removeLegacyRouteRule(pair); err != nil { @@ -615,8 +615,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { } // removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls -func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { - ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) +func (r *router) removeLegacyRouteRule(pair types.RouterPair) error { + ruleKey := types.GenRuleKey(types.ForwardingFormat, pair) if rule, exists := r.rules[ruleKey]; exists { if err := r.conn.DelRule(rule); err != nil { @@ -651,7 +651,7 @@ func (r *router) RemoveAllLegacyRouteRules() error { var merr *multierror.Error for k, rule := range r.rules { - if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) { + if !strings.HasPrefix(k, types.ForwardingFormatPrefix) { continue } if err := r.conn.DelRule(rule); err != nil { @@ -829,7 +829,7 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error } // RemoveNatRule removes the prerouting mark rule -func (r *router) RemoveNatRule(pair firewall.RouterPair) error { +func (r *router) RemoveNatRule(pair types.RouterPair) error { if err := r.refreshRulesMap(); err != nil { return fmt.Errorf(refreshRulesMapError, err) } @@ -838,7 +838,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { return fmt.Errorf("remove prerouting rule: %w", err) } - if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { + if err := r.removeNatRule(types.GetInversePair(pair)); err != nil { return fmt.Errorf("remove inverse prerouting rule: %w", err) } @@ -854,8 +854,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { return nil } -func (r *router) removeNatRule(pair firewall.RouterPair) error { - ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair) +func (r *router) removeNatRule(pair types.RouterPair) error { + ruleKey := types.GenRuleKey(types.PreroutingFormat, pair) if rule, exists := r.rules[ruleKey]; exists { err := r.conn.DelRule(rule) @@ -931,7 +931,7 @@ func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any } } -func applyPort(port *firewall.Port, isSource bool) []expr.Any { +func applyPort(port *types.Port, isSource bool) []expr.Any { if port == nil { return nil } @@ -987,3 +987,27 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any { return exprs } + +func mergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { + if len(prefixes) == 0 { + return prefixes + } + + merged := []netip.Prefix{prefixes[0]} + for _, prefix := range prefixes[1:] { + last := merged[len(merged)-1] + if last.Contains(prefix.Addr()) { + // If the current prefix is contained within the last merged prefix, skip it + continue + } + if prefix.Contains(last.Addr()) { + // If the current prefix contains the last merged prefix, replace it + merged[len(merged)-1] = prefix + } else { + // Otherwise, add the current prefix to the merged list + merged = append(merged, prefix) + } + } + + return merged +} diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index afc4d5c39..86497e69b 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "net/netip" "os/exec" + "reflect" "testing" "github.com/coreos/go-iptables/iptables" @@ -15,8 +16,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" + "github.com/netbirdio/netbird/client/firewall/types" ) const ( @@ -97,7 +98,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) { testingExpression = append(testingExpression, sourceExp...) testingExpression = append(testingExpression, destExp...) - natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair) + natRuleKey := types.GenRuleKey(types.PreroutingFormat, testCase.InputPair) found := 0 for _, chain := range rtr.chains { if chain.Name == chainNamePrerouting { @@ -139,7 +140,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) { require.NoError(t, err, "should add NAT rule") // Verify the rule was added - natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair) + natRuleKey := types.GenRuleKey(types.PreroutingFormat, testCase.InputPair) found := false rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting]) require.NoError(t, err, "should list rules") @@ -209,22 +210,22 @@ func TestRouter_AddRouteFiltering(t *testing.T) { name string sources []netip.Prefix destination netip.Prefix - proto firewall.Protocol - sPort *firewall.Port - dPort *firewall.Port - direction firewall.RuleDirection - action firewall.Action + proto types.Protocol + sPort *types.Port + dPort *types.Port + direction types.RuleDirection + action types.Action expectSet bool }{ { name: "Basic TCP rule with single source", sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, destination: netip.MustParsePrefix("10.0.0.0/24"), - proto: firewall.ProtocolTCP, + proto: types.ProtocolTCP, sPort: nil, - dPort: &firewall.Port{Values: []int{80}}, - direction: firewall.RuleDirectionIN, - action: firewall.ActionAccept, + dPort: &types.Port{Values: []int{80}}, + direction: types.RuleDirectionIN, + action: types.ActionAccept, expectSet: false, }, { @@ -234,77 +235,77 @@ func TestRouter_AddRouteFiltering(t *testing.T) { netip.MustParsePrefix("192.168.0.0/16"), }, destination: netip.MustParsePrefix("10.0.0.0/8"), - proto: firewall.ProtocolUDP, - sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true}, + proto: types.ProtocolUDP, + sPort: &types.Port{Values: []int{1024, 2048}, IsRange: true}, dPort: nil, - direction: firewall.RuleDirectionOUT, - action: firewall.ActionDrop, + direction: types.RuleDirectionOUT, + action: types.ActionDrop, expectSet: true, }, { name: "All protocols rule", sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, destination: netip.MustParsePrefix("0.0.0.0/0"), - proto: firewall.ProtocolALL, + proto: types.ProtocolALL, sPort: nil, dPort: nil, - direction: firewall.RuleDirectionIN, - action: firewall.ActionAccept, + direction: types.RuleDirectionIN, + action: types.ActionAccept, expectSet: false, }, { name: "ICMP rule", sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, destination: netip.MustParsePrefix("10.0.0.0/8"), - proto: firewall.ProtocolICMP, + proto: types.ProtocolICMP, sPort: nil, dPort: nil, - direction: firewall.RuleDirectionIN, - action: firewall.ActionAccept, + direction: types.RuleDirectionIN, + action: types.ActionAccept, expectSet: false, }, { name: "TCP rule with multiple source ports", sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")}, destination: netip.MustParsePrefix("192.168.0.0/16"), - proto: firewall.ProtocolTCP, - sPort: &firewall.Port{Values: []int{80, 443, 8080}}, + proto: types.ProtocolTCP, + sPort: &types.Port{Values: []int{80, 443, 8080}}, dPort: nil, - direction: firewall.RuleDirectionOUT, - action: firewall.ActionAccept, + direction: types.RuleDirectionOUT, + action: types.ActionAccept, expectSet: false, }, { name: "UDP rule with single IP and port range", sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")}, destination: netip.MustParsePrefix("10.0.0.0/24"), - proto: firewall.ProtocolUDP, + proto: types.ProtocolUDP, sPort: nil, - dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true}, - direction: firewall.RuleDirectionIN, - action: firewall.ActionDrop, + dPort: &types.Port{Values: []int{5000, 5100}, IsRange: true}, + direction: types.RuleDirectionIN, + action: types.ActionDrop, expectSet: false, }, { name: "TCP rule with source and destination ports", sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, destination: netip.MustParsePrefix("172.16.0.0/16"), - proto: firewall.ProtocolTCP, - sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true}, - dPort: &firewall.Port{Values: []int{22}}, - direction: firewall.RuleDirectionOUT, - action: firewall.ActionAccept, + proto: types.ProtocolTCP, + sPort: &types.Port{Values: []int{1024, 65535}, IsRange: true}, + dPort: &types.Port{Values: []int{22}}, + direction: types.RuleDirectionOUT, + action: types.ActionAccept, expectSet: false, }, { name: "Drop all incoming traffic", sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, destination: netip.MustParsePrefix("192.168.0.0/24"), - proto: firewall.ProtocolALL, + proto: types.ProtocolALL, sPort: nil, dPort: nil, - direction: firewall.RuleDirectionIN, - action: firewall.ActionDrop, + direction: types.RuleDirectionIN, + action: types.ActionDrop, expectSet: false, }, } @@ -441,7 +442,7 @@ func TestNftablesCreateIpSet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - setName := firewall.GenerateSetName(tt.sources) + setName := types.GenerateSetName(tt.sources) set, err := r.createIpSet(setName, tt.sources) if err != nil { t.Logf("Failed to create IP set: %v", err) @@ -506,7 +507,7 @@ func TestNftablesCreateIpSet(t *testing.T) { } } -func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) { +func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto types.Protocol, sPort, dPort *types.Port, direction types.RuleDirection, action types.Action, expectSet bool) { t.Helper() assert.NotNil(t, rule, "Rule should not be nil") @@ -515,21 +516,21 @@ func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, desti if expectSet { assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources") } else if len(sources) == 1 && sources[0].Bits() != 0 { - if direction == firewall.RuleDirectionIN { + if direction == types.RuleDirectionIN { assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0]) } else { assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0]) } } - if direction == firewall.RuleDirectionIN { + if direction == types.RuleDirectionIN { assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination) } else { assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination) } // Verify protocol - if proto != firewall.ProtocolALL { + if proto != types.ProtocolALL { assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto) } @@ -582,7 +583,7 @@ func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) b return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0 } -func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool { +func containsPort(exprs []expr.Any, port *types.Port, isSource bool) bool { var offset uint32 = 2 // Default offset for destination port if isSource { offset = 0 // Offset for source port @@ -619,7 +620,7 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool { return false } -func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool { +func containsProtocol(exprs []expr.Any, proto types.Protocol) bool { var metaFound, cmpFound bool expectedProto, _ := protoToInt(proto) for _, e := range exprs { @@ -637,13 +638,13 @@ func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool { return metaFound && cmpFound } -func containsAction(exprs []expr.Any, action firewall.Action) bool { +func containsAction(exprs []expr.Any, action types.Action) bool { for _, e := range exprs { if verdict, ok := e.(*expr.Verdict); ok { switch action { - case firewall.ActionAccept: + case types.ActionAccept: return verdict.Kind == expr.VerdictAccept - case firewall.ActionDrop: + case types.ActionDrop: return verdict.Kind == expr.VerdictDrop } } @@ -714,3 +715,121 @@ func deleteWorkTable() { } } } + +func TestMergeIPRanges(t *testing.T) { + tests := []struct { + name string + input []netip.Prefix + expected []netip.Prefix + }{ + { + name: "Empty input", + input: []netip.Prefix{}, + expected: []netip.Prefix{}, + }, + { + name: "Single range", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "Two non-overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + }, + { + name: "One range containing another", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "One range containing another (different order)", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "Overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.1.128/25"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "Overlapping ranges (different order)", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.128/25"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "Multiple overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/24"), + netip.MustParsePrefix("192.168.1.128/25"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "Partially overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/23"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/25"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/23"), + netip.MustParsePrefix("192.168.2.0/25"), + }, + }, + { + name: "IPv6 ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::/32"), + netip.MustParsePrefix("2001:db8:1::/48"), + netip.MustParsePrefix("2001:db8:2::/48"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::/32"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := mergeIPRanges(tt.input) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected) + } + }) + } +} diff --git a/client/firewall/test/cases_linux.go b/client/firewall/test/cases_linux.go index 267e93efd..860e13531 100644 --- a/client/firewall/test/cases_linux.go +++ b/client/firewall/test/cases_linux.go @@ -3,7 +3,7 @@ package test import ( "net/netip" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + firewall "github.com/netbirdio/netbird/client/firewall/types" ) var ( diff --git a/client/firewall/manager/forward_rule.go b/client/firewall/types/forward_rule.go similarity index 86% rename from client/firewall/manager/forward_rule.go rename to client/firewall/types/forward_rule.go index 52bfbd9b0..815a973c1 100644 --- a/client/firewall/manager/forward_rule.go +++ b/client/firewall/types/forward_rule.go @@ -1,11 +1,10 @@ -package manager +package types import ( "fmt" "net/netip" ) -// ForwardRule todo figure out better place to this to avoid circular imports type ForwardRule struct { Protocol Protocol DestinationPort Port diff --git a/client/firewall/types/ipset.go b/client/firewall/types/ipset.go new file mode 100644 index 000000000..a5c8360b3 --- /dev/null +++ b/client/firewall/types/ipset.go @@ -0,0 +1,25 @@ +package types + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "net/netip" + "strings" +) + +// GenerateSetName generates a unique name for an ipset based on the given sources. +func GenerateSetName(sources []netip.Prefix) string { + // sort for consistent naming + SortPrefixes(sources) + + var sourcesStr strings.Builder + for _, src := range sources { + sourcesStr.WriteString(src.String()) + } + + hash := sha256.Sum256([]byte(sourcesStr.String())) + shortHash := hex.EncodeToString(hash[:])[:8] + + return fmt.Sprintf("nb-%s", shortHash) +} diff --git a/client/firewall/types/ipset_test.go b/client/firewall/types/ipset_test.go new file mode 100644 index 000000000..846e389f5 --- /dev/null +++ b/client/firewall/types/ipset_test.go @@ -0,0 +1,71 @@ +package types + +import ( + "net/netip" + "regexp" + "testing" +) + +func TestGenerateSetName(t *testing.T) { + t.Run("Different orders result in same hash", func(t *testing.T) { + prefixes1 := []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + } + prefixes2 := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("192.168.1.0/24"), + } + + result1 := GenerateSetName(prefixes1) + result2 := GenerateSetName(prefixes2) + + if result1 != result2 { + t.Errorf("Different orders produced different hashes: %s != %s", result1, result2) + } + }) + + t.Run("Result format is correct", func(t *testing.T) { + prefixes := []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + } + + result := GenerateSetName(prefixes) + + matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result) + if err != nil { + t.Fatalf("Error matching regex: %v", err) + } + if !matched { + t.Errorf("Result format is incorrect: %s", result) + } + }) + + t.Run("Empty input produces consistent result", func(t *testing.T) { + result1 := GenerateSetName([]netip.Prefix{}) + result2 := GenerateSetName([]netip.Prefix{}) + + if result1 != result2 { + t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2) + } + }) + + t.Run("IPv4 and IPv6 mixing", func(t *testing.T) { + prefixes1 := []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("2001:db8::/32"), + } + prefixes2 := []netip.Prefix{ + netip.MustParsePrefix("2001:db8::/32"), + netip.MustParsePrefix("192.168.1.0/24"), + } + + result1 := GenerateSetName(prefixes1) + result2 := GenerateSetName(prefixes2) + + if result1 != result2 { + t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2) + } + }) +} diff --git a/client/firewall/types/netip.go b/client/firewall/types/netip.go new file mode 100644 index 000000000..551ac8c9d --- /dev/null +++ b/client/firewall/types/netip.go @@ -0,0 +1,20 @@ +package types + +import ( + "net/netip" + "sort" +) + +// SortPrefixes sorts the given slice of netip.Prefix in place. +// It sorts first by IP address, then by prefix length (most specific to least specific). +func SortPrefixes(prefixes []netip.Prefix) { + sort.Slice(prefixes, func(i, j int) bool { + addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr()) + if addrCmp != 0 { + return addrCmp < 0 + } + + // If IP addresses are the same, compare prefix lengths (longer prefixes first) + return prefixes[i].Bits() > prefixes[j].Bits() + }) +} diff --git a/client/firewall/manager/port.go b/client/firewall/types/port.go similarity index 83% rename from client/firewall/manager/port.go rename to client/firewall/types/port.go index fa7e5a5dc..81e8ec000 100644 --- a/client/firewall/manager/port.go +++ b/client/firewall/types/port.go @@ -1,11 +1,10 @@ -package manager +package types import ( "strconv" ) // Port of the address for firewall rule -// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package type Port struct { // IsRange is true Values contains two values, the first is the start port, the second is the end port IsRange bool diff --git a/client/firewall/manager/protocol.go b/client/firewall/types/protocol.go similarity index 80% rename from client/firewall/manager/protocol.go rename to client/firewall/types/protocol.go index 67a1090a6..1085f2efe 100644 --- a/client/firewall/manager/protocol.go +++ b/client/firewall/types/protocol.go @@ -1,7 +1,6 @@ -package manager +package types // Protocol is the protocol of the port -// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package type Protocol string const ( diff --git a/client/firewall/manager/routerpair.go b/client/firewall/types/router_pair.go similarity index 96% rename from client/firewall/manager/routerpair.go rename to client/firewall/types/router_pair.go index 8c94b7dd4..e2e2cf47a 100644 --- a/client/firewall/manager/routerpair.go +++ b/client/firewall/types/router_pair.go @@ -1,4 +1,4 @@ -package manager +package types import ( "net/netip" diff --git a/client/firewall/types/rule.go b/client/firewall/types/rule.go new file mode 100644 index 000000000..5238e6321 --- /dev/null +++ b/client/firewall/types/rule.go @@ -0,0 +1,43 @@ +package types + +import "fmt" + +const ( + PreroutingFormat = "netbird-prerouting-%s-%t" + NatFormat = "netbird-nat-%s-%t" + ForwardingFormat = "netbird-fwd-%s-%t" + ForwardingFormatPrefix = "netbird-fwd-" +) + +// Rule abstraction should be implemented by each firewall manager +// +// Each firewall type for different OS can use different type +// of the properties to hold data of the created rule +type Rule interface { + // GetRuleID returns the rule id + GetRuleID() string +} + +// RuleDirection is the traffic direction which a rule is applied +type RuleDirection int + +const ( + // RuleDirectionIN applies to filters that handlers incoming traffic + RuleDirectionIN RuleDirection = iota + // RuleDirectionOUT applies to filters that handlers outgoing traffic + RuleDirectionOUT +) + +// Action is the action to be taken on a rule +type Action int + +const ( + // ActionAccept is the action to accept a packet + ActionAccept Action = iota + // ActionDrop is the action to drop a packet + ActionDrop +) + +func GenRuleKey(format string, pair RouterPair) string { + return fmt.Sprintf(format, pair.ID, pair.Inverse) +} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index a217e9252..6f6cd0723 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -13,7 +13,8 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + firewall "github.com/netbirdio/netbird/client/firewall/interface" + "github.com/netbirdio/netbird/client/firewall/types" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" @@ -46,7 +47,7 @@ type Manager struct { wgNetwork *net.IPNet decoders sync.Pool wgIface IFaceMapper - nativeFirewall firewall.Manager + nativeFirewall firewall.Firewall mutex sync.RWMutex @@ -74,7 +75,7 @@ func Create(iface IFaceMapper) (*Manager, error) { return create(iface) } -func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) { +func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Firewall) (*Manager, error) { mgr, err := create(iface) if err != nil { return nil, err @@ -134,7 +135,7 @@ func (m *Manager) IsServerRouteSupported() bool { } } -func (m *Manager) AddNatRule(pair firewall.RouterPair) error { +func (m *Manager) AddNatRule(pair types.RouterPair) error { if m.nativeFirewall == nil { return errRouteNotSupported } @@ -142,7 +143,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error { } // RemoveNatRule removes a routing firewall rule -func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { +func (m *Manager) RemoveNatRule(pair types.RouterPair) error { if m.nativeFirewall == nil { return errRouteNotSupported } @@ -155,19 +156,19 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { // rule ID as comment for the rule func (m *Manager) AddPeerFiltering( ip net.IP, - proto firewall.Protocol, - sPort *firewall.Port, - dPort *firewall.Port, - action firewall.Action, + proto types.Protocol, + sPort *types.Port, + dPort *types.Port, + action types.Action, _ string, comment string, -) ([]firewall.Rule, error) { +) ([]types.Rule, error) { r := Rule{ id: uuid.New().String(), ip: ip, ipLayer: layers.LayerTypeIPv6, matchByIP: true, - drop: action == firewall.ActionDrop, + drop: action == types.ActionDrop, comment: comment, } if ipNormalized := ip.To4(); ipNormalized != nil { @@ -188,16 +189,16 @@ func (m *Manager) AddPeerFiltering( } switch proto { - case firewall.ProtocolTCP: + case types.ProtocolTCP: r.protoLayer = layers.LayerTypeTCP - case firewall.ProtocolUDP: + case types.ProtocolUDP: r.protoLayer = layers.LayerTypeUDP - case firewall.ProtocolICMP: + case types.ProtocolICMP: r.protoLayer = layers.LayerTypeICMPv4 if r.ipLayer == layers.LayerTypeIPv6 { r.protoLayer = layers.LayerTypeICMPv6 } - case firewall.ProtocolALL: + case types.ProtocolALL: r.protoLayer = layerTypeAll } @@ -207,17 +208,17 @@ func (m *Manager) AddPeerFiltering( } m.incomingRules[r.ip.String()][r.id] = r m.mutex.Unlock() - return []firewall.Rule{&r}, nil + return []types.Rule{&r}, nil } -func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { +func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto types.Protocol, sPort *types.Port, dPort *types.Port, action types.Action) (types.Rule, error) { if m.nativeFirewall == nil { return nil, errRouteNotSupported } return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) } -func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { +func (m *Manager) DeleteRouteRule(rule types.Rule) error { if m.nativeFirewall == nil { return errRouteNotSupported } @@ -225,7 +226,7 @@ func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { } // DeletePeerRule from the firewall by rule definition -func (m *Manager) DeletePeerRule(rule firewall.Rule) error { +func (m *Manager) DeletePeerRule(rule types.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -254,12 +255,12 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { func (m *Manager) Flush() error { return nil } // AddDNATRule adds a DNAT rule -func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { +func (m *Manager) AddDNATRule(rule types.ForwardRule) (types.Rule, error) { return nil, fmt.Errorf("not implemented") } // DeleteDNATRule deletes a DNAT rule -func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { +func (m *Manager) DeleteDNATRule(rule types.Rule) error { return nil } diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go index 4a210bf47..f20c0b9bd 100644 --- a/client/firewall/uspfilter/uspfilter_bench_test.go +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -12,7 +12,7 @@ import ( "github.com/google/gopacket/layers" "github.com/stretchr/testify/require" - fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/types" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/iface/device" ) @@ -90,8 +90,8 @@ func BenchmarkCoreFiltering(b *testing.B) { stateful: false, setupFunc: func(m *Manager) { // Single rule allowing all traffic - _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, - fw.ActionAccept, "", "allow all") + _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolALL, nil, nil, + types.ActionAccept, "", "allow all") require.NoError(b, err) }, desc: "Baseline: Single 'allow all' rule without connection tracking", @@ -111,10 +111,10 @@ func BenchmarkCoreFiltering(b *testing.B) { // Add explicit rules matching return traffic pattern for i := 0; i < 1000; i++ { // Simulate realistic ruleset size ip := generateRandomIPs(1)[0] - _, err := m.AddPeerFiltering(ip, fw.ProtocolTCP, - &fw.Port{Values: []int{1024 + i}}, - &fw.Port{Values: []int{80}}, - fw.ActionAccept, "", "explicit return") + _, err := m.AddPeerFiltering(ip, types.ProtocolTCP, + &types.Port{Values: []int{1024 + i}}, + &types.Port{Values: []int{80}}, + types.ActionAccept, "", "explicit return") require.NoError(b, err) } }, @@ -125,8 +125,8 @@ func BenchmarkCoreFiltering(b *testing.B) { stateful: true, setupFunc: func(m *Manager) { // Add some basic rules but rely on state for established connections - _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil, - fw.ActionDrop, "", "default drop") + _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP, nil, nil, + types.ActionDrop, "", "default drop") require.NoError(b, err) }, desc: "Connection tracking with established connections", @@ -587,10 +587,10 @@ func BenchmarkLongLivedConnections(b *testing.B) { // Setup initial state based on scenario if sc.rules { // Single rule to allow all return traffic from port 80 - _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, - &fw.Port{Values: []int{80}}, + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP, + &types.Port{Values: []int{80}}, nil, - fw.ActionAccept, "", "return traffic") + types.ActionAccept, "", "return traffic") require.NoError(b, err) } @@ -678,10 +678,10 @@ func BenchmarkShortLivedConnections(b *testing.B) { // Setup initial state based on scenario if sc.rules { // Single rule to allow all return traffic from port 80 - _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, - &fw.Port{Values: []int{80}}, + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP, + &types.Port{Values: []int{80}}, nil, - fw.ActionAccept, "", "return traffic") + types.ActionAccept, "", "return traffic") require.NoError(b, err) } @@ -796,10 +796,10 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { // Setup initial state based on scenario if sc.rules { - _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, - &fw.Port{Values: []int{80}}, + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP, + &types.Port{Values: []int{80}}, nil, - fw.ActionAccept, "", "return traffic") + types.ActionAccept, "", "return traffic") require.NoError(b, err) } @@ -883,10 +883,10 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { }) if sc.rules { - _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, - &fw.Port{Values: []int{80}}, + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), types.ProtocolTCP, + &types.Port{Values: []int{80}}, nil, - fw.ActionAccept, "", "return traffic") + types.ActionAccept, "", "return traffic") require.NoError(b, err) } diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 7e87443aa..7d4a24df2 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -11,7 +11,7 @@ import ( "github.com/google/gopacket/layers" "github.com/stretchr/testify/require" - fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/types" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" @@ -43,12 +43,12 @@ func TestManagerCreate(t *testing.T) { m, err := Create(ifaceMock) if err != nil { - t.Errorf("failed to create Manager: %v", err) + t.Errorf("failed to create Firewall: %v", err) return } if m == nil { - t.Error("Manager is nil") + t.Error("Firewall is nil") } } @@ -63,14 +63,14 @@ func TestManagerAddPeerFiltering(t *testing.T) { m, err := Create(ifaceMock) if err != nil { - t.Errorf("failed to create Manager: %v", err) + t.Errorf("failed to create Firewall: %v", err) return } ip := net.ParseIP("192.168.1.1") - proto := fw.ProtocolTCP - port := &fw.Port{Values: []int{80}} - action := fw.ActionDrop + proto := types.ProtocolTCP + port := &types.Port{Values: []int{80}} + action := types.ActionDrop comment := "Test rule" rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment) @@ -97,14 +97,14 @@ func TestManagerDeleteRule(t *testing.T) { m, err := Create(ifaceMock) if err != nil { - t.Errorf("failed to create Manager: %v", err) + t.Errorf("failed to create Firewall: %v", err) return } ip := net.ParseIP("192.168.1.1") - proto := fw.ProtocolTCP - port := &fw.Port{Values: []int{80}} - action := fw.ActionDrop + proto := types.ProtocolTCP + port := &types.Port{Values: []int{80}} + action := types.ActionDrop comment := "Test rule 2" rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment) @@ -138,7 +138,7 @@ func TestAddUDPPacketHook(t *testing.T) { tests := []struct { name string in bool - expDir fw.RuleDirection + expDir types.RuleDirection ip net.IP dPort uint16 hook func([]byte) bool @@ -147,7 +147,7 @@ func TestAddUDPPacketHook(t *testing.T) { { name: "Test Outgoing UDP Packet Hook", in: false, - expDir: fw.RuleDirectionOUT, + expDir: types.RuleDirectionOUT, ip: net.IPv4(10, 168, 0, 1), dPort: 8000, hook: func([]byte) bool { return true }, @@ -155,7 +155,7 @@ func TestAddUDPPacketHook(t *testing.T) { { name: "Test Incoming UDP Packet Hook", in: true, - expDir: fw.RuleDirectionIN, + expDir: types.RuleDirectionIN, ip: net.IPv6loopback, dPort: 9000, hook: func([]byte) bool { return false }, @@ -217,14 +217,14 @@ func TestManagerReset(t *testing.T) { m, err := Create(ifaceMock) if err != nil { - t.Errorf("failed to create Manager: %v", err) + t.Errorf("failed to create Firewall: %v", err) return } ip := net.ParseIP("192.168.1.1") - proto := fw.ProtocolTCP - port := &fw.Port{Values: []int{80}} - action := fw.ActionDrop + proto := types.ProtocolTCP + port := &types.Port{Values: []int{80}} + action := types.ActionDrop comment := "Test rule" _, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment) @@ -235,7 +235,7 @@ func TestManagerReset(t *testing.T) { err = m.Reset(nil) if err != nil { - t.Errorf("failed to reset Manager: %v", err) + t.Errorf("failed to reset Firewall: %v", err) return } @@ -251,7 +251,7 @@ func TestNotMatchByIP(t *testing.T) { m, err := Create(ifaceMock) if err != nil { - t.Errorf("failed to create Manager: %v", err) + t.Errorf("failed to create Firewall: %v", err) return } m.wgNetwork = &net.IPNet{ @@ -260,8 +260,8 @@ func TestNotMatchByIP(t *testing.T) { } ip := net.ParseIP("0.0.0.0") - proto := fw.ProtocolUDP - action := fw.ActionAccept + proto := types.ProtocolUDP + action := types.ActionAccept comment := "Test rule" _, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment) @@ -304,7 +304,7 @@ func TestNotMatchByIP(t *testing.T) { } if err = m.Reset(nil); err != nil { - t.Errorf("failed to reset Manager: %v", err) + t.Errorf("failed to reset Firewall: %v", err) return } } @@ -319,7 +319,7 @@ func TestRemovePacketHook(t *testing.T) { // creating manager instance manager, err := Create(iface) if err != nil { - t.Fatalf("Failed to create Manager: %s", err) + t.Fatalf("Failed to create Firewall: %s", err) } defer func() { require.NoError(t, manager.Reset(nil)) @@ -463,8 +463,8 @@ func TestUSPFilterCreatePerformance(t *testing.T) { ip := net.ParseIP("10.20.0.100") start := time.Now() for i := 0; i < testMax; i++ { - port := &fw.Port{Values: []int{1000 + i}} - _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic") + port := &types.Port{Values: []int{1000 + i}} + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, types.ActionAccept, "", "accept HTTP traffic") require.NoError(t, err, "failed to add rule") } diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go index 8ce73655d..48c3e6ebf 100644 --- a/client/internal/acl/id/id.go +++ b/client/internal/acl/id/id.go @@ -7,7 +7,7 @@ import ( "net/netip" "strconv" - "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/types" ) type RuleID string @@ -19,12 +19,12 @@ func (r RuleID) GetRuleID() string { func GenerateRouteRuleKey( sources []netip.Prefix, destination netip.Prefix, - proto manager.Protocol, - sPort *manager.Port, - dPort *manager.Port, - action manager.Action, + proto types.Protocol, + sPort *types.Port, + dPort *types.Port, + action types.Action, ) RuleID { - manager.SortPrefixes(sources) + types.SortPrefixes(sources) h := sha256.New() diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 0ade5d7ce..d0000e793 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -15,7 +15,8 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/interface" + "github.com/netbirdio/netbird/client/firewall/types" "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/ssh" mgmProto "github.com/netbirdio/netbird/management/proto" @@ -30,17 +31,17 @@ type Manager interface { // DefaultManager uses firewall manager to handle type DefaultManager struct { - firewall firewall.Manager + firewall _interface.Firewall ipsetCounter int - peerRulesPairs map[id.RuleID][]firewall.Rule + peerRulesPairs map[id.RuleID][]types.Rule routeRules map[id.RuleID]struct{} mutex sync.Mutex } -func NewDefaultManager(fm firewall.Manager) *DefaultManager { +func NewDefaultManager(fm _interface.Firewall) *DefaultManager { return &DefaultManager{ firewall: fm, - peerRulesPairs: make(map[id.RuleID][]firewall.Rule), + peerRulesPairs: make(map[id.RuleID][]types.Rule), routeRules: make(map[id.RuleID]struct{}), } } @@ -132,7 +133,7 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { ) } - newRulePairs := make(map[id.RuleID][]firewall.Rule) + newRulePairs := make(map[id.RuleID][]types.Rule) ipsetByRuleSelectors := make(map[string]string) for _, r := range rules { @@ -251,7 +252,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul func (d *DefaultManager) protoRuleToFirewallRule( r *mgmProto.FirewallRule, ipsetName string, -) (id.RuleID, []firewall.Rule, error) { +) (id.RuleID, []types.Rule, error) { ip := net.ParseIP(r.PeerIP) if ip == nil { return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule") @@ -267,13 +268,13 @@ func (d *DefaultManager) protoRuleToFirewallRule( return "", nil, fmt.Errorf("skipping firewall rule: %s", err) } - var port *firewall.Port + var port *types.Port if r.Port != "" { value, err := strconv.Atoi(r.Port) if err != nil { return "", nil, fmt.Errorf("invalid port, skipping firewall rule") } - port = &firewall.Port{ + port = &types.Port{ Values: []int{value}, } } @@ -283,7 +284,7 @@ func (d *DefaultManager) protoRuleToFirewallRule( return ruleID, rulesPair, nil } - var rules []firewall.Rule + var rules []types.Rule switch r.Direction { case mgmProto.RuleDirection_IN: rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") @@ -304,12 +305,12 @@ func (d *DefaultManager) protoRuleToFirewallRule( func (d *DefaultManager) addInRules( ip net.IP, - protocol firewall.Protocol, - port *firewall.Port, - action firewall.Action, + protocol types.Protocol, + port *types.Port, + action types.Action, ipsetName string, comment string, -) ([]firewall.Rule, error) { +) ([]types.Rule, error) { rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("add firewall rule: %w", err) @@ -320,12 +321,12 @@ func (d *DefaultManager) addInRules( func (d *DefaultManager) addOutRules( ip net.IP, - protocol firewall.Protocol, - port *firewall.Port, - action firewall.Action, + protocol types.Protocol, + port *types.Port, + action types.Action, ipsetName string, comment string, -) ([]firewall.Rule, error) { +) ([]types.Rule, error) { if shouldSkipInvertedRule(protocol, port) { return nil, nil } @@ -341,10 +342,10 @@ func (d *DefaultManager) addOutRules( // getPeerRuleID() returns unique ID for the rule based on its parameters. func (d *DefaultManager) getPeerRuleID( ip net.IP, - proto firewall.Protocol, + proto types.Protocol, direction int, - port *firewall.Port, - action firewall.Action, + port *types.Port, + action types.Action, comment string, ) id.RuleID { idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment @@ -491,7 +492,7 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port) } -func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) { +func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]types.Rule) { log.Debugf("rollback ACL to previous state") for _, rules := range newRulePairs { for _, rule := range rules { @@ -502,49 +503,49 @@ func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) { } } -func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) { +func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (types.Protocol, error) { switch protocol { case mgmProto.RuleProtocol_TCP: - return firewall.ProtocolTCP, nil + return types.ProtocolTCP, nil case mgmProto.RuleProtocol_UDP: - return firewall.ProtocolUDP, nil + return types.ProtocolUDP, nil case mgmProto.RuleProtocol_ICMP: - return firewall.ProtocolICMP, nil + return types.ProtocolICMP, nil case mgmProto.RuleProtocol_ALL: - return firewall.ProtocolALL, nil + return types.ProtocolALL, nil default: - return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String()) + return types.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String()) } } -func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) bool { - return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil +func shouldSkipInvertedRule(protocol types.Protocol, port *types.Port) bool { + return protocol == types.ProtocolALL || protocol == types.ProtocolICMP || port == nil } -func convertFirewallAction(action mgmProto.RuleAction) (firewall.Action, error) { +func convertFirewallAction(action mgmProto.RuleAction) (types.Action, error) { switch action { case mgmProto.RuleAction_ACCEPT: - return firewall.ActionAccept, nil + return types.ActionAccept, nil case mgmProto.RuleAction_DROP: - return firewall.ActionDrop, nil + return types.ActionDrop, nil default: - return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action) + return types.ActionDrop, fmt.Errorf("invalid action type: %d", action) } } -func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port { +func convertPortInfo(portInfo *mgmProto.PortInfo) *types.Port { if portInfo == nil { return nil } if portInfo.GetPort() != 0 { - return &firewall.Port{ + return &types.Port{ Values: []int{int(portInfo.GetPort())}, } } if portInfo.GetRange() != nil { - return &firewall.Port{ + return &types.Port{ IsRange: true, Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)}, } diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 6049b4f48..bb83fa591 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -7,7 +7,7 @@ import ( "github.com/golang/mock/gomock" "github.com/netbirdio/netbird/client/firewall" - "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/interface" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/acl/mocks" mgmProto "github.com/netbirdio/netbird/management/proto" @@ -56,7 +56,7 @@ func TestDefaultManager(t *testing.T) { t.Errorf("create firewall: %v", err) return } - defer func(fw manager.Manager) { + defer func(fw _interface.Firewall) { _ = fw.Reset(nil) }(fw) acl := NewDefaultManager(fw) @@ -349,7 +349,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { t.Errorf("create firewall: %v", err) return } - defer func(fw manager.Manager) { + defer func(fw _interface.Firewall) { _ = fw.Reset(nil) }(fw) acl := NewDefaultManager(fw) diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 968f2d398..f0d95cf2c 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -9,7 +9,8 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/interface" + "github.com/netbirdio/netbird/client/firewall/types" ) const ( @@ -19,13 +20,13 @@ const ( ) type Manager struct { - firewall firewall.Manager + firewall _interface.Firewall - fwRules []firewall.Rule + fwRules []types.Rule dnsForwarder *DNSForwarder } -func NewManager(fw firewall.Manager) *Manager { +func NewManager(fw _interface.Firewall) *Manager { return &Manager{ firewall: fw, } @@ -79,7 +80,7 @@ func (m *Manager) Stop(ctx context.Context) error { } func (h *Manager) allowDNSFirewall() error { - dport := &firewall.Port{ + dport := &types.Port{ IsRange: false, Values: []int{ListenPort}, } @@ -88,7 +89,7 @@ func (h *Manager) allowDNSFirewall() error { return nil } - dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "") + dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, types.ProtocolUDP, nil, dport, types.ActionAccept, "", "") if err != nil { log.Errorf("failed to add allow DNS router rules, err: %v", err) return err diff --git a/client/internal/engine.go b/client/internal/engine.go index 40486168d..6ee2720b1 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -25,7 +25,8 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/firewall" - firewallManager "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/interface" + "github.com/netbirdio/netbird/client/firewall/types" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" @@ -169,7 +170,7 @@ type Engine struct { statusRecorder *peer.Status - firewall firewallManager.Manager + firewall _interface.Firewall routeManager routemanager.Manager acl acl.Manager dnsForwardMgr *dnsfwd.Manager @@ -504,15 +505,15 @@ func (e *Engine) initFirewall() error { } rosenpassPort := e.rpManager.GetAddress().Port - port := firewallManager.Port{Values: []int{rosenpassPort}} + port := types.Port{Values: []int{rosenpassPort}} // this rule is static and will be torn down on engine down by the firewall manager if _, err := e.firewall.AddPeerFiltering( net.IP{0, 0, 0, 0}, - firewallManager.ProtocolUDP, + types.ProtocolUDP, nil, &port, - firewallManager.ActionAccept, + types.ActionAccept, "", "", ); err != nil { @@ -540,10 +541,10 @@ func (e *Engine) blockLanAccess() { if _, err := e.firewall.AddRouteFiltering( []netip.Prefix{v4}, network, - firewallManager.ProtocolALL, + types.ProtocolALL, nil, nil, - firewallManager.ActionDrop, + types.ActionDrop, ); err != nil { merr = multierror.Append(merr, fmt.Errorf("add fw rule for network %s: %w", network, err)) } @@ -1774,7 +1775,7 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error { } var merr *multierror.Error - forwardingRules := make([]firewallManager.ForwardRule, 0, len(rules)) + forwardingRules := make([]types.ForwardRule, 0, len(rules)) for _, rule := range rules { proto, err := convertToFirewallProtocol(rule.GetProtocol()) if err != nil { @@ -1800,7 +1801,7 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error { continue } - forwardRule := firewallManager.ForwardRule{ + forwardRule := types.ForwardRule{ Protocol: proto, DestinationPort: *dstPortInfo, TranslatedAddress: translateIP, diff --git a/client/internal/engine_moc.go b/client/internal/engine_moc.go index 33989c0fa..65894889a 100644 --- a/client/internal/engine_moc.go +++ b/client/internal/engine_moc.go @@ -5,7 +5,7 @@ import ( log "github.com/sirupsen/logrus" - firewallManager "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/types" "github.com/netbirdio/netbird/client/internal/ingressgw" ) @@ -14,192 +14,192 @@ func (e *Engine) mocForwardRules() { e.ingressGatewayMgr = ingressgw.NewManager(e.firewall) } err := e.ingressGatewayMgr.Update( - []firewallManager.ForwardRule{ + []types.ForwardRule{ { Protocol: "tcp", - DestinationPort: firewallManager.Port{IsRange: false, Values: []int{10000}}, + DestinationPort: types.Port{IsRange: false, Values: []int{10000}}, TranslatedAddress: netip.MustParseAddr("100.64.31.206"), - TranslatedPort: firewallManager.Port{IsRange: false, Values: []int{20000}}, + TranslatedPort: types.Port{IsRange: false, Values: []int{20000}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10100, 10199}}, + DestinationPort: types.Port{IsRange: true, Values: []int{10100, 10199}}, TranslatedAddress: netip.MustParseAddr("100.64.10.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20100, 20199}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{20100, 20199}}, }, { Protocol: "tcp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10200, 10299}}, + DestinationPort: types.Port{IsRange: true, Values: []int{10200, 10299}}, TranslatedAddress: netip.MustParseAddr("100.64.31.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20200, 20299}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{20200, 20299}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10300, 10399}}, + DestinationPort: types.Port{IsRange: true, Values: []int{10300, 10399}}, TranslatedAddress: netip.MustParseAddr("100.64.10.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20300, 20399}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{20300, 20399}}, }, { Protocol: "tcp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10100, 10199}}, + DestinationPort: types.Port{IsRange: true, Values: []int{10100, 10199}}, TranslatedAddress: netip.MustParseAddr("100.64.10.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20100, 20199}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{20100, 20199}}, }, { Protocol: "tcp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10400, 10499}}, + DestinationPort: types.Port{IsRange: true, Values: []int{10400, 10499}}, TranslatedAddress: netip.MustParseAddr("100.64.31.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20400, 20499}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{20400, 20499}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10500, 10599}}, + DestinationPort: types.Port{IsRange: true, Values: []int{10500, 10599}}, TranslatedAddress: netip.MustParseAddr("100.64.10.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20500, 20599}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{20500, 20599}}, }, { Protocol: "tcp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10600, 10699}}, + DestinationPort: types.Port{IsRange: true, Values: []int{10600, 10699}}, TranslatedAddress: netip.MustParseAddr("100.64.31.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20600, 20699}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{20600, 20699}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10700, 10799}}, + DestinationPort: types.Port{IsRange: true, Values: []int{10700, 10799}}, TranslatedAddress: netip.MustParseAddr("100.64.10.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20700, 20799}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{20700, 20799}}, }, { Protocol: "tcp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10800, 10899}}, + DestinationPort: types.Port{IsRange: true, Values: []int{10800, 10899}}, TranslatedAddress: netip.MustParseAddr("100.64.31.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20800, 20899}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{20800, 20899}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{10900, 10999}}, + DestinationPort: types.Port{IsRange: true, Values: []int{10900, 10999}}, TranslatedAddress: netip.MustParseAddr("100.64.10.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{20900, 20999}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{20900, 20999}}, }, { Protocol: "tcp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11000, 11099}}, + DestinationPort: types.Port{IsRange: true, Values: []int{11000, 11099}}, TranslatedAddress: netip.MustParseAddr("100.64.31.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21000, 21099}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{21000, 21099}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11100, 11199}}, + DestinationPort: types.Port{IsRange: true, Values: []int{11100, 11199}}, TranslatedAddress: netip.MustParseAddr("100.64.10.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21100, 21199}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{21100, 21199}}, }, { Protocol: "tcp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11200, 11299}}, + DestinationPort: types.Port{IsRange: true, Values: []int{11200, 11299}}, TranslatedAddress: netip.MustParseAddr("100.64.31.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21200, 21299}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{21200, 21299}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11300, 11399}}, + DestinationPort: types.Port{IsRange: true, Values: []int{11300, 11399}}, TranslatedAddress: netip.MustParseAddr("100.64.10.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21300, 21399}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{21300, 21399}}, }, { Protocol: "tcp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11400, 11499}}, + DestinationPort: types.Port{IsRange: true, Values: []int{11400, 11499}}, TranslatedAddress: netip.MustParseAddr("100.64.31.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21400, 21499}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{21400, 21499}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11500, 11599}}, + DestinationPort: types.Port{IsRange: true, Values: []int{11500, 11599}}, TranslatedAddress: netip.MustParseAddr("100.64.10.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21500, 21599}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{21500, 21599}}, }, { Protocol: "tcp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11600, 11699}}, + DestinationPort: types.Port{IsRange: true, Values: []int{11600, 11699}}, TranslatedAddress: netip.MustParseAddr("100.64.31.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21600, 21699}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{21600, 21699}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11700, 11799}}, + DestinationPort: types.Port{IsRange: true, Values: []int{11700, 11799}}, TranslatedAddress: netip.MustParseAddr("100.64.10.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21700, 21799}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{21700, 21799}}, }, { Protocol: "tcp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11800, 11899}}, + DestinationPort: types.Port{IsRange: true, Values: []int{11800, 11899}}, TranslatedAddress: netip.MustParseAddr("100.64.31.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21800, 21899}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{21800, 21899}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{11900, 11999}}, + DestinationPort: types.Port{IsRange: true, Values: []int{11900, 11999}}, TranslatedAddress: netip.MustParseAddr("100.64.10.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{21900, 21999}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{21900, 21999}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12000, 12099}}, + DestinationPort: types.Port{IsRange: true, Values: []int{12000, 12099}}, TranslatedAddress: netip.MustParseAddr("100.64.31.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22000, 22099}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{22000, 22099}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12100, 12199}}, + DestinationPort: types.Port{IsRange: true, Values: []int{12100, 12199}}, TranslatedAddress: netip.MustParseAddr("100.64.10.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22100, 22199}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{22100, 22199}}, }, { Protocol: "tcp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12200, 12299}}, + DestinationPort: types.Port{IsRange: true, Values: []int{12200, 12299}}, TranslatedAddress: netip.MustParseAddr("100.64.31.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22200, 22299}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{22200, 22299}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12300, 12399}}, + DestinationPort: types.Port{IsRange: true, Values: []int{12300, 12399}}, TranslatedAddress: netip.MustParseAddr("100.64.10.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22300, 22399}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{22300, 22399}}, }, { Protocol: "tcp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12400, 12499}}, + DestinationPort: types.Port{IsRange: true, Values: []int{12400, 12499}}, TranslatedAddress: netip.MustParseAddr("100.64.31.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22400, 22499}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{22400, 22499}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12500, 12599}}, + DestinationPort: types.Port{IsRange: true, Values: []int{12500, 12599}}, TranslatedAddress: netip.MustParseAddr("100.64.10.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22500, 22599}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{22500, 22599}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12600, 12699}}, + DestinationPort: types.Port{IsRange: true, Values: []int{12600, 12699}}, TranslatedAddress: netip.MustParseAddr("100.64.31.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22600, 22699}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{22600, 22699}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12700, 12799}}, + DestinationPort: types.Port{IsRange: true, Values: []int{12700, 12799}}, TranslatedAddress: netip.MustParseAddr("100.64.10.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22700, 22799}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{22700, 22799}}, }, { Protocol: "tcp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12800, 12899}}, + DestinationPort: types.Port{IsRange: true, Values: []int{12800, 12899}}, TranslatedAddress: netip.MustParseAddr("100.64.31.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22800, 22899}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{22800, 22899}}, }, { Protocol: "udp", - DestinationPort: firewallManager.Port{IsRange: true, Values: []int{12900, 12999}}, + DestinationPort: types.Port{IsRange: true, Values: []int{12900, 12999}}, TranslatedAddress: netip.MustParseAddr("100.64.10.206"), - TranslatedPort: firewallManager.Port{IsRange: true, Values: []int{22900, 22999}}, + TranslatedPort: types.Port{IsRange: true, Values: []int{22900, 22999}}, }, }, ) diff --git a/client/internal/ingressgw/manager.go b/client/internal/ingressgw/manager.go index f7781754d..2c50fa129 100644 --- a/client/internal/ingressgw/manager.go +++ b/client/internal/ingressgw/manager.go @@ -8,29 +8,30 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - firewallManager "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/interface" + "github.com/netbirdio/netbird/client/firewall/types" ) type RulePair struct { - firewallManager.ForwardRule - firewallManager.Rule + types.ForwardRule + types.Rule } type Manager struct { - firewallManager firewallManager.Manager + firewall _interface.Firewall rules map[string]RulePair // keys is the ID of the ForwardRule rulesMu sync.Mutex } -func NewManager(firewall firewallManager.Manager) *Manager { +func NewManager(firewall _interface.Firewall) *Manager { return &Manager{ - firewallManager: firewall, - rules: make(map[string]RulePair), + firewall: firewall, + rules: make(map[string]RulePair), } } -func (h *Manager) Update(forwardRules []firewallManager.ForwardRule) error { +func (h *Manager) Update(forwardRules []types.ForwardRule) error { h.rulesMu.Lock() defer h.rulesMu.Unlock() @@ -48,7 +49,7 @@ func (h *Manager) Update(forwardRules []firewallManager.ForwardRule) error { continue } - rule, err := h.firewallManager.AddDNATRule(fwdRule) + rule, err := h.firewall.AddDNATRule(fwdRule) if err != nil { mErr = multierror.Append(mErr, fmt.Errorf("failed to add forward rule '%s': %v", fwdRule.String(), err)) continue @@ -62,7 +63,7 @@ func (h *Manager) Update(forwardRules []firewallManager.ForwardRule) error { // Remove deleted rules for id, rulePair := range toDelete { - if err := h.firewallManager.DeleteDNATRule(rulePair.Rule); err != nil { + if err := h.firewall.DeleteDNATRule(rulePair.Rule); err != nil { mErr = multierror.Append(mErr, fmt.Errorf("failed to delete forward rule '%s': %v", rulePair.ForwardRule.String(), err)) } delete(h.rules, id) @@ -78,18 +79,18 @@ func (h *Manager) Close() error { log.Infof("clean up all forward rules (%d)", len(h.rules)) var mErr *multierror.Error for _, rule := range h.rules { - if err := h.firewallManager.DeleteDNATRule(rule.Rule); err != nil { + if err := h.firewall.DeleteDNATRule(rule.Rule); err != nil { mErr = multierror.Append(mErr, fmt.Errorf("failed to delete forward rule '%s': %v", rule, err)) } } return nberrors.FormatErrorOrNil(mErr) } -func (h *Manager) Rules() []firewallManager.ForwardRule { +func (h *Manager) Rules() []types.ForwardRule { h.rulesMu.Lock() defer h.rulesMu.Unlock() - rules := make([]firewallManager.ForwardRule, 0, len(h.rules)) + rules := make([]types.ForwardRule, 0, len(h.rules)) for _, rulePair := range h.rules { rules = append(rules, rulePair.ForwardRule) } diff --git a/client/internal/message_convert.go b/client/internal/message_convert.go index 6141bcee6..7e9f8d2aa 100644 --- a/client/internal/message_convert.go +++ b/client/internal/message_convert.go @@ -6,39 +6,39 @@ import ( "net" "net/netip" - firewallManager "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/types" mgmProto "github.com/netbirdio/netbird/management/proto" ) -func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewallManager.Protocol, error) { +func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (types.Protocol, error) { switch protocol { case mgmProto.RuleProtocol_TCP: - return firewallManager.ProtocolTCP, nil + return types.ProtocolTCP, nil case mgmProto.RuleProtocol_UDP: - return firewallManager.ProtocolUDP, nil + return types.ProtocolUDP, nil case mgmProto.RuleProtocol_ICMP: - return firewallManager.ProtocolICMP, nil + return types.ProtocolICMP, nil case mgmProto.RuleProtocol_ALL: - return firewallManager.ProtocolALL, nil + return types.ProtocolALL, nil default: - return firewallManager.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String()) + return types.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String()) } } // convertPortInfo todo: write validation for portInfo -func convertPortInfo(portInfo *mgmProto.PortInfo) *firewallManager.Port { +func convertPortInfo(portInfo *mgmProto.PortInfo) *types.Port { if portInfo == nil { return nil } if portInfo.GetPort() != 0 { - return &firewallManager.Port{ + return &types.Port{ Values: []int{int(portInfo.GetPort())}, } } if portInfo.GetRange() != nil { - return &firewallManager.Port{ + return &types.Port{ IsRange: true, Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)}, } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 113afbb68..1b7386cf2 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -11,7 +11,7 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + firewall "github.com/netbirdio/netbird/client/firewall/types" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/relay" diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 6f73fb166..9033b426e 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -14,7 +14,7 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/interface" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/netstack" @@ -44,7 +44,7 @@ type Manager interface { GetClientRoutesWithNetID() map[route.NetID][]*route.Route SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string - EnableServerRouter(firewall firewall.Manager) error + EnableServerRouter(firewall _interface.Firewall) error Stop(stateManager *statemanager.Manager) } @@ -214,7 +214,7 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector { return routeselector.NewRouteSelector() } -func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { +func (m *DefaultManager) EnableServerRouter(firewall _interface.Firewall) error { if m.disableServerRoutes { log.Info("server routes are disabled") return nil diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 64fdffceb..52412aef4 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -3,7 +3,7 @@ package routemanager import ( "context" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/interface" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/routeselector" @@ -78,7 +78,7 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList } -func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error { +func (m *MockManager) EnableServerRouter(firewall _interface.Firewall) error { panic("implement me") } diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go index e9cfa0826..a11bba313 100644 --- a/client/internal/routemanager/server_android.go +++ b/client/internal/routemanager/server_android.go @@ -6,7 +6,7 @@ import ( "context" "fmt" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/interface" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/route" @@ -22,6 +22,6 @@ func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error { return nil } -func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (*serverRouter, error) { +func newServerRouter(context.Context, iface.IWGIface, _interface.Firewall, *peer.Status) (*serverRouter, error) { return nil, fmt.Errorf("server route not supported on this os") } diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index b60cb318e..3c1d06591 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -10,7 +10,8 @@ import ( log "github.com/sirupsen/logrus" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/interface" + "github.com/netbirdio/netbird/client/firewall/types" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" @@ -21,12 +22,12 @@ type serverRouter struct { mux sync.Mutex ctx context.Context routes map[route.ID]*route.Route - firewall firewall.Manager + firewall _interface.Firewall wgInterface iface.IWGIface statusRecorder *peer.Status } -func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) { +func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall _interface.Firewall, statusRecorder *peer.Status) (*serverRouter, error) { return &serverRouter{ ctx: ctx, routes: make(map[route.ID]*route.Route), @@ -167,7 +168,7 @@ func (m *serverRouter) cleanUp() { m.statusRecorder.UpdateLocalPeerState(state) } -func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { +func routeToRouterPair(route *route.Route) (types.RouterPair, error) { // TODO: add ipv6 source := getDefaultPrefix(route.Network) @@ -177,7 +178,7 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { destination = getDefaultPrefix(destination) } - return firewall.RouterPair{ + return types.RouterPair{ ID: route.ID, Source: source, Destination: destination, diff --git a/client/server/forwardingrules.go b/client/server/forwardingrules.go index f86aca475..1580cac90 100644 --- a/client/server/forwardingrules.go +++ b/client/server/forwardingrules.go @@ -3,7 +3,7 @@ package server import ( "context" - firewall "github.com/netbirdio/netbird/client/firewall/manager" + firewall "github.com/netbirdio/netbird/client/firewall/types" "github.com/netbirdio/netbird/client/proto" )