[client] Fix bind exclusion routes (#4154)

This commit is contained in:
Viktor Liu
2025-07-21 12:13:21 +02:00
committed by GitHub
parent 40fdeda838
commit d6ed9c037e
25 changed files with 230 additions and 151 deletions

View File

@@ -61,7 +61,6 @@ import (
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
)
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -138,9 +137,6 @@ type Engine struct {
connMgr *ConnMgr
beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc
// rpManager is a Rosenpass manager
rpManager *rosenpass.Manager
@@ -409,12 +405,8 @@ func (e *Engine) Start() error {
DisableClientRoutes: e.config.DisableClientRoutes,
DisableServerRoutes: e.config.DisableServerRoutes,
})
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil {
if err := e.routeManager.Init(); err != nil {
log.Errorf("Failed to initialize route manager: %s", err)
} else {
e.beforePeerHook = beforePeerHook
e.afterPeerHook = afterPeerHook
}
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
@@ -1261,10 +1253,6 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
return fmt.Errorf("peer already exists: %s", peerKey)
}
if e.beforePeerHook != nil && e.afterPeerHook != nil {
conn.AddBeforeAddPeerHook(e.beforePeerHook)
conn.AddAfterRemovePeerHook(e.afterPeerHook)
}
return nil
}

View File

@@ -400,7 +400,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
StatusRecorder: engine.statusRecorder,
RelayManager: relayMgr,
})
_, _, err = engine.routeManager.Init()
err = engine.routeManager.Init()
require.NoError(t, err)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },

View File

@@ -26,7 +26,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
)
@@ -106,10 +105,6 @@ type Conn struct {
workerRelay *WorkerRelay
wgWatcherWg sync.WaitGroup
connIDRelay nbnet.ConnectionID
connIDICE nbnet.ConnectionID
beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
rosenpassRemoteKey []byte
@@ -267,8 +262,6 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
conn.freeUpConnID()
if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
conn.onDisconnected(conn.config.WgConfig.RemoteKey)
}
@@ -293,13 +286,6 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
}
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
}
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
}
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) {
conn.onConnected = handler
@@ -387,10 +373,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
ep = directEp
}
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
conn.Log.Errorf("Before add peer hook failed: %v", err)
}
conn.workerRelay.DisableWgWatcher()
// todo consider to run conn.wgWatcherWg.Wait() here
@@ -503,10 +485,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return
}
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
conn.Log.Errorf("Before add peer hook failed: %v", err)
}
wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
if err := wgProxy.CloseConn(); err != nil {
@@ -707,36 +685,6 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
return true
}
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
conn.connIDICE = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDICE, ip); err != nil {
return err
}
}
return nil
}
func (conn *Conn) freeUpConnID() {
if conn.connIDRelay != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDRelay); err != nil {
conn.Log.Errorf("After remove peer hook failed: %v", err)
}
}
conn.connIDRelay = ""
}
if conn.connIDICE != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDICE); err != nil {
conn.Log.Errorf("After remove peer hook failed: %v", err)
}
}
conn.connIDICE = ""
}
}
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.Log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{

View File

@@ -812,7 +812,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}
params := common.HandlerParams{
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
}
// create new clientNetwork
client := &Watcher{

View File

@@ -44,7 +44,7 @@ import (
// Manager is a route manager interface
type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
Init() error
UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
TriggerSelection(route.HAMap)
@@ -201,11 +201,11 @@ func (m *DefaultManager) setupRefCounters(useNoop bool) {
}
// Init sets up the routing
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
func (m *DefaultManager) Init() error {
m.routeSelector = m.initSelector()
if nbnet.CustomRoutingDisabled() || m.disableClientRoutes {
return nil, nil, nil
return nil
}
if err := m.sysOps.CleanupRouting(nil); err != nil {
@@ -219,13 +219,12 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
ips := resolveURLsToIPs(initialAddresses)
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager)
if err != nil {
return nil, nil, fmt.Errorf("setup routing: %w", err)
if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
return fmt.Errorf("setup routing: %w", err)
}
log.Info("Routing setup complete")
return beforePeerHook, afterPeerHook, nil
return nil
}
func (m *DefaultManager) initSelector() *routeselector.RouteSelector {

View File

@@ -430,7 +430,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
StatusRecorder: statusRecorder,
})
_, _, err = routeManager.Init()
err = routeManager.Init()
require.NoError(t, err, "should init route manager")
defer routeManager.Stop(nil)

View File

@@ -9,7 +9,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/util/net"
)
// MockManager is the mock instance of a route manager
@@ -23,8 +22,8 @@ type MockManager struct {
StopFunc func(manager *statemanager.Manager)
}
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
return nil, nil, nil
func (m *MockManager) Init() error {
return nil
}
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface

View File

@@ -33,4 +33,4 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
func (n *Notifier) GetInitialRouteRanges() []string {
return []string{}
}
}

View File

@@ -6,6 +6,7 @@ import (
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
@@ -56,6 +57,10 @@ type SysOps struct {
// seq is an atomic counter for generating unique sequence numbers for route messages
//nolint:unused // only used on BSD systems
seq atomic.Uint32
localSubnetsCache []*net.IPNet
localSubnetsCacheMu sync.RWMutex
localSubnetsCacheTime time.Time
}
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {

View File

@@ -10,11 +10,10 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
)
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
return nil, nil, nil
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
return nil
}
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {

View File

@@ -10,6 +10,7 @@ import (
"net/netip"
"runtime"
"strconv"
"time"
"github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute"
@@ -24,6 +25,8 @@ import (
nbnet "github.com/netbirdio/netbird/util/net"
)
const localSubnetsCacheTTL = 15 * time.Minute
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
@@ -31,7 +34,7 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
stateManager.RegisterState(&ShutdownState{})
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
@@ -75,7 +78,10 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
r.refCounter = refCounter
return r.setupHooks(initAddresses, stateManager)
if err := r.setupHooks(initAddresses, stateManager); err != nil {
return fmt.Errorf("setup hooks: %w", err)
}
return nil
}
// updateState updates state on every change so it will be persisted regularly
@@ -128,18 +134,14 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
return Nexthop{}, fmt.Errorf("get next hop: %w", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP)
exitNextHop := Nexthop{
IP: nexthop.IP,
Intf: nexthop.Intf,
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.Intf)
exitNextHop := nexthop
vpnAddr := vpnIntf.Address().IP
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
exitNextHop = initialNextHop
}
@@ -152,12 +154,37 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
}
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
r.localSubnetsCacheMu.RLock()
cacheAge := time.Since(r.localSubnetsCacheTime)
subnets := r.localSubnetsCache
r.localSubnetsCacheMu.RUnlock()
if cacheAge > localSubnetsCacheTTL || subnets == nil {
r.localSubnetsCacheMu.Lock()
if time.Since(r.localSubnetsCacheTime) > localSubnetsCacheTTL || r.localSubnetsCache == nil {
r.refreshLocalSubnetsCache()
}
subnets = r.localSubnetsCache
r.localSubnetsCacheMu.Unlock()
}
for _, subnet := range subnets {
if subnet.Contains(prefix.Addr().AsSlice()) {
return true, subnet
}
}
return false, nil
}
func (r *SysOps) refreshLocalSubnetsCache() {
localInterfaces, err := net.Interfaces()
if err != nil {
log.Errorf("Failed to get local interfaces: %v", err)
return false, nil
return
}
var newSubnets []*net.IPNet
for _, intf := range localInterfaces {
addrs, err := intf.Addrs()
if err != nil {
@@ -171,14 +198,12 @@ func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet)
log.Errorf("Failed to convert address to IPNet: %v", addr)
continue
}
if ipnet.Contains(prefix.Addr().AsSlice()) {
return true, ipnet
}
newSubnets = append(newSubnets, ipnet)
}
}
return false, nil
r.localSubnetsCache = newSubnets
r.localSubnetsCacheTime = time.Now()
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
@@ -264,7 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
return r.removeFromRouteTable(prefix, nextHop)
}
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
@@ -289,9 +314,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return nil
}
var merr *multierror.Error
for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err)
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
}
}
@@ -300,11 +327,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return ctx.Err()
}
var result *multierror.Error
var merr *multierror.Error
for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP))
merr = multierror.Append(merr, beforeHook(connID, ip.IP))
}
return nberrors.FormatErrorOrNil(result)
return nberrors.FormatErrorOrNil(merr)
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
@@ -319,7 +346,16 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return afterHook(connID)
})
return beforeHook, afterHook, nil
nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
if _, err := r.refCounter.Decrement(prefix); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
r.updateState(stateManager)
return nil
})
return nberrors.FormatErrorOrNil(merr)
}
func GetNextHop(ip netip.Addr) (Nexthop, error) {

View File

@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil)
err := r.SetupRouting(nil, nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))
@@ -341,7 +341,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil)
err := r.SetupRouting(nil, nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))
@@ -484,7 +484,7 @@ func setupTestEnv(t *testing.T) {
})
r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil)
err := r.SetupRouting(nil, nil)
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))

View File

@@ -10,14 +10,13 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
)
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{})
return nil, nil, nil
return nil
}
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {

View File

@@ -72,7 +72,7 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
if !nbnet.AdvancedRouting() {
log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses, stateManager)
@@ -89,7 +89,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
rules := getSetupRules()
for _, rule := range rules {
if err := addRule(rule); err != nil {
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
return fmt.Errorf("%s: %w", rule.description, err)
}
}
@@ -104,7 +104,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
}
originalSysctl = originalValues
return nil, nil, nil
return nil
}
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.

View File

@@ -18,10 +18,9 @@ import (
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
)
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
return r.setupRefCounter(initAddresses, stateManager)
}

View File

@@ -19,7 +19,6 @@ import (
"golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
)
const InfiniteLifetime = 0xffffffff
@@ -137,7 +136,7 @@ const (
RouteDeleted
)
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
return r.setupRefCounter(initAddresses, stateManager)
}