mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[client] Clean up bsd routes independently of the state file (#4688)
This commit is contained in:
@@ -106,7 +106,7 @@ type DefaultManager struct {
|
|||||||
func NewManager(config ManagerConfig) *DefaultManager {
|
func NewManager(config ManagerConfig) *DefaultManager {
|
||||||
mCTX, cancel := context.WithCancel(config.Context)
|
mCTX, cancel := context.WithCancel(config.Context)
|
||||||
notifier := notifier.NewNotifier()
|
notifier := notifier.NewNotifier()
|
||||||
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
|
sysOps := systemops.New(config.WGInterface, notifier)
|
||||||
|
|
||||||
if runtime.GOOS == "windows" && config.WGInterface != nil {
|
if runtime.GOOS == "windows" && config.WGInterface != nil {
|
||||||
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
|
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
|
||||||
|
|||||||
8
client/internal/routemanager/systemops/flush_nonbsd.go
Normal file
8
client/internal/routemanager/systemops/flush_nonbsd.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd)
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
// FlushMarkedRoutes is a no-op on non-BSD platforms.
|
||||||
|
func (r *SysOps) FlushMarkedRoutes() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Cleanup() error {
|
func (s *ShutdownState) Cleanup() error {
|
||||||
sysops := NewSysOps(nil, nil)
|
sysOps := New(nil, nil)
|
||||||
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
|
sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable)
|
||||||
sysops.refCounter.LoadData((*ExclusionCounter)(s))
|
sysOps.refCounter.LoadData((*ExclusionCounter)(s))
|
||||||
|
|
||||||
return sysops.refCounter.Flush()
|
return sysOps.refCounter.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ type SysOps struct {
|
|||||||
localSubnetsCacheTime time.Time
|
localSubnetsCacheTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||||
return &SysOps{
|
return &SysOps{
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
notifier: notifier,
|
notifier: notifier,
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) {
|
|||||||
_, intf = setupDummyInterface(t)
|
_, intf = setupDummyInterface(t)
|
||||||
nexthop = Nexthop{netip.Addr{}, intf}
|
nexthop = Nexthop{netip.Addr{}, intf}
|
||||||
|
|
||||||
r := NewSysOps(nil, nil)
|
r := New(nil, nil)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for i := 0; i < 1024; i++ {
|
for i := 0; i < 1024; i++ {
|
||||||
@@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin
|
|||||||
|
|
||||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||||
|
|
||||||
r := NewSysOps(nil, nil)
|
r := New(nil, nil)
|
||||||
err = r.addToRouteTable(prefix, nexthop)
|
err = r.addToRouteTable(prefix, nexthop)
|
||||||
require.NoError(t, err, "Failed to add route to table")
|
require.NoError(t, err, "Failed to add route to table")
|
||||||
|
|
||||||
|
|||||||
@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
|
|||||||
|
|
||||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := New(wgInterface, nil)
|
||||||
advancedRouting := nbnet.AdvancedRouting()
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
|
|||||||
|
|
||||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := New(wgInterface, nil)
|
||||||
advancedRouting := nbnet.AdvancedRouting()
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) {
|
|||||||
assert.NoError(t, wgInterface.Close())
|
assert.NoError(t, wgInterface.Close())
|
||||||
})
|
})
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := New(wgInterface, nil)
|
||||||
advancedRouting := nbnet.AdvancedRouting()
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err, "setupRouting should not return err")
|
require.NoError(t, err, "setupRouting should not return err")
|
||||||
|
|||||||
@@ -7,19 +7,39 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/route"
|
"golang.org/x/net/route"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
|
||||||
|
)
|
||||||
|
|
||||||
|
var routeProtoFlag int
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
switch os.Getenv(envRouteProtoFlag) {
|
||||||
|
case "2":
|
||||||
|
routeProtoFlag = unix.RTF_PROTO2
|
||||||
|
case "3":
|
||||||
|
routeProtoFlag = unix.RTF_PROTO3
|
||||||
|
default:
|
||||||
|
routeProtoFlag = unix.RTF_PROTO1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
@@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout
|
|||||||
return r.cleanupRefCounter(stateManager)
|
return r.cleanupRefCounter(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
|
||||||
|
func (r *SysOps) FlushMarkedRoutes() error {
|
||||||
|
rib, err := retryFetchRIB()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("fetch routing table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse routing table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
flushedCount := 0
|
||||||
|
|
||||||
|
for _, msg := range msgs {
|
||||||
|
rtMsg, ok := msg.(*route.RouteMessage)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if rtMsg.Flags&routeProtoFlag == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
routeInfo, err := MsgToRoute(rtMsg)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Skipping route flush: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
nexthop := Nexthop{
|
||||||
|
IP: routeInfo.Gw,
|
||||||
|
Intf: routeInfo.Interface,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
flushedCount++
|
||||||
|
log.Debugf("Flushed marked route: %s", routeInfo.Dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
if flushedCount > 0 {
|
||||||
|
log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
|
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
|
||||||
}
|
}
|
||||||
@@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func(
|
|||||||
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
||||||
msg = &route.RouteMessage{
|
msg = &route.RouteMessage{
|
||||||
Type: action,
|
Type: action,
|
||||||
Flags: unix.RTF_UP,
|
Flags: unix.RTF_UP | routeProtoFlag,
|
||||||
Version: unix.RTM_VERSION,
|
Version: unix.RTM_VERSION,
|
||||||
Seq: r.getSeq(),
|
Seq: r.getSeq(),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage,
|
|||||||
data, err := os.ReadFile(m.filePath)
|
data, err := os.ReadFile(m.filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
log.Debug("state file does not exist")
|
log.Debugf("state file %s does not exist", m.filePath)
|
||||||
return nil, nil // nolint:nilnil
|
return nil, nil // nolint:nilnil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("read state file: %w", err)
|
return nil, fmt.Errorf("read state file: %w", err)
|
||||||
|
|||||||
@@ -10,7 +10,9 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clean up any remaining routes independently of the state file
|
||||||
|
if !nbnet.AdvancedRouting() {
|
||||||
|
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user