Revert "Merge branch 'main' into feature/remote-debug"

This reverts commit 6d6333058c, reversing
changes made to 446aded1f7.
This commit is contained in:
aliamerj
2025-10-06 12:24:48 +03:00
parent 6d6333058c
commit ba7793ae7b
288 changed files with 3117 additions and 8952 deletions

View File

@@ -2,13 +2,11 @@ package dnsinterceptor
import (
"context"
"errors"
"fmt"
"net/netip"
"runtime"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
@@ -28,8 +26,6 @@ import (
"github.com/netbirdio/netbird/route"
)
const dnsTimeout = 8 * time.Second
type domainMap map[domain.Domain][]netip.Prefix
type internalDNATer interface {
@@ -247,7 +243,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout)
if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return
@@ -258,20 +254,9 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
startTime := time.Now()
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
@@ -583,16 +568,3 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
}
return
}
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
if d.statusRecorder == nil {
return ""
}
peerState, err := d.statusRecorder.GetPeer(peerKey)
if err != nil {
return fmt.Sprintf(" (peer %s state error: %v)", peerKey[:8], err)
}
return fmt.Sprintf(" (peer %s)", nbdns.FormatPeerStatus(&peerState))
}

View File

@@ -36,9 +36,9 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
)
@@ -108,10 +108,6 @@ func NewManager(config ManagerConfig) *DefaultManager {
notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
if runtime.GOOS == "windows" && config.WGInterface != nil {
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
}
dm := &DefaultManager{
ctx: mCTX,
stop: cancel,
@@ -212,7 +208,7 @@ func (m *DefaultManager) Init() error {
return nil
}
if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil {
if err := m.sysOps.CleanupRouting(nil); err != nil {
log.Warnf("Failed cleaning up routing: %v", err)
}
@@ -223,7 +219,7 @@ func (m *DefaultManager) Init() error {
ips := resolveURLsToIPs(initialAddresses)
if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil {
if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
return fmt.Errorf("setup routing: %w", err)
}
@@ -289,15 +285,11 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
}
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil {
if err := m.sysOps.CleanupRouting(stateManager); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
}
if runtime.GOOS == "windows" {
nbnet.SetVPNInterfaceName("")
}
}
m.mux.Lock()
@@ -376,11 +368,7 @@ func (m *DefaultManager) UpdateRoutes(
var merr *multierror.Error
if !m.disableClientRoutes {
// Update route selector based on management server's isSelected status
m.updateRouteSelectorFromManagement(clientRoutes)
filteredClientRoutes := m.routeSelector.FilterSelectedExitNodes(clientRoutes)
filteredClientRoutes := m.routeSelector.FilterSelected(clientRoutes)
if err := m.updateSystemRoutes(filteredClientRoutes); err != nil {
merr = multierror.Append(merr, fmt.Errorf("update system routes: %w", err))
@@ -442,7 +430,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
m.mux.Lock()
defer m.mux.Unlock()
networks = m.routeSelector.FilterSelectedExitNodes(networks)
networks = m.routeSelector.FilterSelected(networks)
m.notifier.OnNewRoutes(networks)
@@ -595,106 +583,3 @@ func resolveURLsToIPs(urls []string) []net.IP {
}
return ips
}
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
if len(exitNodeInfo.allIDs) == 0 {
return
}
m.updateExitNodeSelections(exitNodeInfo)
m.logExitNodeUpdate(exitNodeInfo)
}
type exitNodeInfo struct {
allIDs []route.NetID
selectedByManagement []route.NetID
userSelected []route.NetID
userDeselected []route.NetID
}
func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeInfo {
var info exitNodeInfo
for haID, routes := range clientRoutes {
if !m.isExitNodeRoute(routes) {
continue
}
netID := haID.NetID()
info.allIDs = append(info.allIDs, netID)
if m.routeSelector.HasUserSelectionForRoute(netID) {
m.categorizeUserSelection(netID, &info)
} else {
m.checkManagementSelection(routes, netID, &info)
}
}
return info
}
func (m *DefaultManager) isExitNodeRoute(routes []*route.Route) bool {
return len(routes) > 0 && routes[0].Network.String() == vars.ExitNodeCIDR
}
func (m *DefaultManager) categorizeUserSelection(netID route.NetID, info *exitNodeInfo) {
if m.routeSelector.IsSelected(netID) {
info.userSelected = append(info.userSelected, netID)
} else {
info.userDeselected = append(info.userDeselected, netID)
}
}
func (m *DefaultManager) checkManagementSelection(routes []*route.Route, netID route.NetID, info *exitNodeInfo) {
for _, route := range routes {
if !route.SkipAutoApply {
info.selectedByManagement = append(info.selectedByManagement, netID)
break
}
}
}
func (m *DefaultManager) updateExitNodeSelections(info exitNodeInfo) {
routesToDeselect := m.getRoutesToDeselect(info.allIDs)
m.deselectExitNodes(routesToDeselect)
m.selectExitNodesByManagement(info.selectedByManagement, info.allIDs)
}
func (m *DefaultManager) getRoutesToDeselect(allIDs []route.NetID) []route.NetID {
var routesToDeselect []route.NetID
for _, netID := range allIDs {
if !m.routeSelector.HasUserSelectionForRoute(netID) {
routesToDeselect = append(routesToDeselect, netID)
}
}
return routesToDeselect
}
func (m *DefaultManager) deselectExitNodes(routesToDeselect []route.NetID) {
if len(routesToDeselect) == 0 {
return
}
err := m.routeSelector.DeselectRoutes(routesToDeselect, routesToDeselect)
if err != nil {
log.Warnf("Failed to deselect exit nodes: %v", err)
}
}
func (m *DefaultManager) selectExitNodesByManagement(selectedByManagement []route.NetID, allIDs []route.NetID) {
if len(selectedByManagement) == 0 {
return
}
err := m.routeSelector.SelectRoutes(selectedByManagement, true, allIDs)
if err != nil {
log.Warnf("Failed to select exit nodes: %v", err)
}
}
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo) {
log.Debugf("Updated route selector: %d exit nodes available, %d selected by management, %d user-selected, %d user-deselected",
len(info.allIDs), len(info.selectedByManagement), len(info.userSelected), len(info.userDeselected))
}

View File

@@ -190,15 +190,14 @@ func TestManagerUpdateRoutes(t *testing.T) {
name: "No Small Client Route Should Be Added",
inputRoutes: []*route.Route{
{
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("0.0.0.0/0"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
SkipAutoApply: false,
ID: "a",
NetID: "routeA",
Peer: remotePeerKey1,
Network: netip.MustParsePrefix("0.0.0.0/0"),
NetworkType: route.IPv4Network,
Metric: 9999,
Masquerade: false,
Enabled: true,
},
},
inputSerial: 1,

View File

@@ -12,11 +12,11 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
return nil
}
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
return nil
}

View File

@@ -3,6 +3,7 @@
package systemops
import (
"context"
"errors"
"fmt"
"net"
@@ -21,7 +22,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/net/hooks"
nbnet "github.com/netbirdio/netbird/util/net"
)
const localSubnetsCacheTTL = 15 * time.Minute
@@ -95,9 +96,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
return nil
}
hooks.RemoveWriteHooks()
hooks.RemoveCloseHooks()
hooks.RemoveAddressRemoveHooks()
// TODO: Remove hooks selectively
nbnet.RemoveDialerHooks()
nbnet.RemoveListenerHooks()
if err := r.refCounter.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
@@ -289,7 +290,12 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
}
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
@@ -298,7 +304,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return nil
}
afterHook := func(connID hooks.ConnectionID) error {
afterHook := func(connID nbnet.ConnectionID) error {
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
@@ -311,20 +317,36 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
var merr *multierror.Error
for _, ip := range initAddresses {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err))
continue
}
if err := beforeHook("init", prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err))
if err := beforeHook("init", ip); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
}
}
hooks.AddWriteHook(beforeHook)
hooks.AddCloseHook(afterHook)
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
if ctx.Err() != nil {
return ctx.Err()
}
hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error {
var merr *multierror.Error
for _, ip := range resolvedIPs {
merr = multierror.Append(merr, beforeHook(connID, ip.IP))
}
return nberrors.FormatErrorOrNil(merr)
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
if _, err := r.refCounter.Decrement(prefix); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}

View File

@@ -22,7 +22,6 @@ import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/client/net"
)
type dialer interface {
@@ -144,11 +143,10 @@ func TestAddVPNRoute(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
err := r.SetupRouting(nil, nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
assert.NoError(t, r.CleanupRouting(nil))
})
intf, err := net.InterfaceByName(wgInterface.Name())
@@ -343,11 +341,10 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
err := r.SetupRouting(nil, nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
assert.NoError(t, r.CleanupRouting(nil))
})
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
@@ -487,11 +484,10 @@ func setupTestEnv(t *testing.T) {
})
r := NewSysOps(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
err := r.SetupRouting(nil, nil)
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
assert.NoError(t, r.CleanupRouting(nil))
})
index, err := net.InterfaceByName(wgInterface.Name())

View File

@@ -12,14 +12,14 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{})
return nil
}
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
r.mu.Lock()
defer r.mu.Unlock()

View File

@@ -20,7 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
nbnet "github.com/netbirdio/netbird/util/net"
)
// IPRule contains IP rule information for debugging
@@ -94,15 +94,15 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) (err error) {
if !advancedRouting {
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
if !nbnet.AdvancedRouting() {
log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses, stateManager)
}
defer func() {
if err != nil {
if cleanErr := r.CleanupRouting(stateManager, advancedRouting); cleanErr != nil {
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
log.Errorf("Error cleaning up routing: %v", cleanErr)
}
}
@@ -132,8 +132,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process.
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
if !advancedRouting {
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
if !nbnet.AdvancedRouting() {
return r.cleanupRefCounter(stateManager)
}

View File

@@ -20,11 +20,11 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
return r.setupRefCounter(initAddresses, stateManager)
}
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
return r.cleanupRefCounter(stateManager)
}

View File

@@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/client/net"
nbnet "github.com/netbirdio/netbird/util/net"
)
type PacketExpectation struct {

View File

@@ -8,7 +8,6 @@ import (
"net/netip"
"os"
"runtime/debug"
"sort"
"strconv"
"sync"
"syscall"
@@ -20,16 +19,9 @@ import (
"golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func init() {
nbnet.GetBestInterfaceFunc = GetBestInterface
}
const (
InfiniteLifetime = 0xffffffff
)
const InfiniteLifetime = 0xffffffff
type RouteUpdateType int
@@ -85,14 +77,6 @@ type MIB_IPFORWARD_TABLE2 struct {
Table [1]MIB_IPFORWARD_ROW2 // Flexible array member
}
// candidateRoute represents a potential route for selection during route lookup
type candidateRoute struct {
interfaceIndex uint32
prefixLength uint8
routeMetric uint32
interfaceMetric int
}
// IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix
type IP_ADDRESS_PREFIX struct {
Prefix SOCKADDR_INET
@@ -193,20 +177,11 @@ const (
RouteDeleted
)
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return nil
}
log.Infof("Using legacy routing setup with ref counters")
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
return r.setupRefCounter(initAddresses, stateManager)
}
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return nil
}
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
return r.cleanupRefCounter(stateManager)
}
@@ -361,7 +336,7 @@ func createIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
if e1 != 0 {
return fmt.Errorf("CreateIpForwardEntry2: %w", e1)
}
return fmt.Errorf("CreateIpForwardEntry2: code %d", windows.NTStatus(r1))
return fmt.Errorf("CreateIpForwardEntry2: code %d", r1)
}
return nil
}
@@ -660,7 +635,10 @@ func getWindowsRoutingTable() (*MIB_IPFORWARD_TABLE2, error) {
func freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) {
if table != nil {
_, _, _ = procFreeMibTable.Call(uintptr(unsafe.Pointer(table)))
ret, _, _ := procFreeMibTable.Call(uintptr(unsafe.Pointer(table)))
if ret != 0 {
log.Warnf("FreeMibTable failed with return code: %d", ret)
}
}
}
@@ -674,7 +652,8 @@ func parseWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) []DetailedRoute {
entryPtr := basePtr + uintptr(i)*entrySize
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
if detailed := buildWindowsDetailedRoute(entry); detailed != nil {
detailed := buildWindowsDetailedRoute(entry)
if detailed != nil {
detailedRoutes = append(detailedRoutes, *detailed)
}
}
@@ -823,46 +802,6 @@ func addZone(ip netip.Addr, interfaceIndex int) netip.Addr {
return ip
}
// parseCandidatesFromTable extracts all matching candidate routes from the routing table
func parseCandidatesFromTable(table *MIB_IPFORWARD_TABLE2, dest netip.Addr, skipInterfaceIndex int) []candidateRoute {
var candidates []candidateRoute
entrySize := unsafe.Sizeof(MIB_IPFORWARD_ROW2{})
basePtr := uintptr(unsafe.Pointer(&table.Table[0]))
for i := uint32(0); i < table.NumEntries; i++ {
entryPtr := basePtr + uintptr(i)*entrySize
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
if candidate := parseCandidateRoute(entry, dest, skipInterfaceIndex); candidate != nil {
candidates = append(candidates, *candidate)
}
}
return candidates
}
// parseCandidateRoute extracts candidate route information from a MIB_IPFORWARD_ROW2 entry
// Returns nil if the route doesn't match the destination or should be skipped
func parseCandidateRoute(entry *MIB_IPFORWARD_ROW2, dest netip.Addr, skipInterfaceIndex int) *candidateRoute {
if skipInterfaceIndex > 0 && int(entry.InterfaceIndex) == skipInterfaceIndex {
return nil
}
destPrefix := parseIPPrefix(entry.DestinationPrefix, int(entry.InterfaceIndex))
if !destPrefix.IsValid() || !destPrefix.Contains(dest) {
return nil
}
interfaceMetric := getInterfaceMetric(entry.InterfaceIndex, entry.DestinationPrefix.Prefix.sin6_family)
return &candidateRoute{
interfaceIndex: entry.InterfaceIndex,
prefixLength: entry.DestinationPrefix.PrefixLength,
routeMetric: entry.Metric,
interfaceMetric: interfaceMetric,
}
}
// getInterfaceMetric retrieves the interface metric for a given interface and address family
func getInterfaceMetric(interfaceIndex uint32, family int16) int {
if interfaceIndex == 0 {
@@ -882,76 +821,6 @@ func getInterfaceMetric(interfaceIndex uint32, family int16) int {
return int(ipInterfaceRow.Metric)
}
// sortRouteCandidates sorts route candidates by priority: prefix length -> route metric -> interface metric
func sortRouteCandidates(candidates []candidateRoute) {
sort.Slice(candidates, func(i, j int) bool {
if candidates[i].prefixLength != candidates[j].prefixLength {
return candidates[i].prefixLength > candidates[j].prefixLength
}
if candidates[i].routeMetric != candidates[j].routeMetric {
return candidates[i].routeMetric < candidates[j].routeMetric
}
return candidates[i].interfaceMetric < candidates[j].interfaceMetric
})
}
// GetBestInterface finds the best interface for reaching a destination,
// excluding the VPN interface to avoid routing loops.
//
// Route selection priority:
// 1. Longest prefix match (most specific route)
// 2. Lowest route metric
// 3. Lowest interface metric
func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) {
var skipInterfaceIndex int
if vpnIntf != "" {
if iface, err := net.InterfaceByName(vpnIntf); err == nil {
skipInterfaceIndex = iface.Index
} else {
// not critical, if we cannot get ahold of the interface then we won't need to skip it
log.Warnf("failed to get VPN interface %s: %v", vpnIntf, err)
}
}
table, err := getWindowsRoutingTable()
if err != nil {
return nil, fmt.Errorf("get routing table: %w", err)
}
defer freeWindowsRoutingTable(table)
candidates := parseCandidatesFromTable(table, dest, skipInterfaceIndex)
if len(candidates) == 0 {
return nil, fmt.Errorf("no route to %s", dest)
}
// Sort routes: prefix length -> route metric -> interface metric
sortRouteCandidates(candidates)
for _, candidate := range candidates {
iface, err := net.InterfaceByIndex(int(candidate.interfaceIndex))
if err != nil {
log.Warnf("failed to get interface by index %d: %v", candidate.interfaceIndex, err)
continue
}
if iface.Flags&net.FlagLoopback != 0 && !dest.IsLoopback() {
continue
}
if iface.Flags&net.FlagUp == 0 {
log.Debugf("interface %s is down, trying next route", iface.Name)
continue
}
log.Debugf("route lookup for %s: selected interface %s (index %d), route metric %d, interface metric %d",
dest, iface.Name, iface.Index, candidate.routeMetric, candidate.interfaceMetric)
return iface, nil
}
return nil, fmt.Errorf("no usable interface found for %s", dest)
}
// formatRouteAge formats the route age in seconds to a human-readable string
func formatRouteAge(ageSeconds uint32) string {
if ageSeconds == 0 {

View File

@@ -15,7 +15,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/client/net"
nbnet "github.com/netbirdio/netbird/util/net"
)
var (

View File

@@ -12,8 +12,18 @@ func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) {
if !ok {
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
}
addr = addr.Unmap()
prefix := netip.PrefixFrom(addr, addr.BitLen())
var prefixLength int
switch {
case addr.Is4():
prefixLength = 32
case addr.Is6():
prefixLength = 128
default:
return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr)
}
prefix := netip.PrefixFrom(addr, prefixLength)
return prefix, nil
}

View File

@@ -13,6 +13,4 @@ var (
Defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
Defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
ExitNodeCIDR = "0.0.0.0/0"
)