[client] Use platform-native routing APIs for freeBSD, macOS and Windows

This commit is contained in:
Viktor Liu
2025-06-04 16:28:58 +02:00
committed by GitHub
parent 87148c503f
commit ea4d13e96d
53 changed files with 1552 additions and 1046 deletions

View File

@@ -6,9 +6,10 @@ import (
"net/netip"
"sync"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
type Nexthop struct {
@@ -30,11 +31,16 @@ func (n Nexthop) String() string {
return fmt.Sprintf("%s @ %d (%s)", n.IP.String(), n.Intf.Index, n.Intf.Name)
}
type wgIface interface {
Address() wgaddr.Address
Name() string
}
type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
wgInterface iface.WGIface
wgInterface wgIface
// prefixes is tracking all the current added prefixes im memory
// (this is used in iOS as all route updates require a full table update)
//nolint
@@ -45,9 +51,27 @@ type SysOps struct {
notifier *notifier.Notifier
}
func NewSysOps(wgInterface iface.WGIface, notifier *notifier.Notifier) *SysOps {
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{
wgInterface: wgInterface,
notifier: notifier,
}
}
func (r *SysOps) validateRoute(prefix netip.Prefix) error {
addr := prefix.Addr()
switch {
case
!addr.IsValid(),
addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsMulticast(),
addr.IsUnspecified() && prefix.Bits() != 0,
r.wgInterface.Address().Network.Contains(addr):
return vars.ErrRouteNotAllowed
}
return nil
}

View File

@@ -8,6 +8,8 @@ import (
"net/netip"
"os/exec"
"regexp"
"runtime"
"strings"
"sync"
"testing"
@@ -33,7 +35,12 @@ func init() {
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"}
var intf *net.Interface
var nexthop Nexthop
_, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf}
r := NewSysOps(nil, nil)
@@ -43,7 +50,7 @@ func TestConcurrentRoutes(t *testing.T) {
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
if err := r.addToRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
@@ -59,7 +66,7 @@ func TestConcurrentRoutes(t *testing.T) {
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
@@ -119,18 +126,39 @@ func TestBits(t *testing.T) {
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
if runtime.GOOS == "darwin" {
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return intf
}
prefix, err := netip.ParsePrefix(ipAddressCIDR)
require.NoError(t, err, "Failed to parse prefix")
netIntf, err := net.InterfaceByName(intf)
require.NoError(t, err, "Failed to get interface by name")
nexthop := Nexthop{netip.Addr{}, netIntf}
r := NewSysOps(nil, nil)
err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
err := r.removeFromRouteTable(prefix, nexthop)
assert.NoError(t, err, "Failed to remove route from table")
})
return "lo0"
return intf
}
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) {
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
t.Helper()
var originalNexthop net.IP
@@ -176,12 +204,40 @@ func fetchOriginalGateway() (net.IP, error) {
return net.ParseIP(matches[1]), nil
}
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
t.Helper()
if runtime.GOOS == "darwin" {
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
}
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
tunName := strings.TrimSpace(string(output))
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
intf, err := net.InterfaceByName(tunName)
require.NoError(t, err, "Failed to get interface by name")
t.Cleanup(func() {
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
}
})
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
}

View File

@@ -17,7 +17,6 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
@@ -106,59 +105,15 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
return nil
}
// TODO: fix: for default our wg address now appears as the default gw
func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified()
}
nexthop, err := GetNextHop(addr)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
if !prefix.Contains(nexthop.IP) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
if nexthop.IP.Is6() {
gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
}
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
nexthop, err = GetNextHop(nexthop.IP)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
return r.addToRouteTable(gatewayPrefix, nexthop)
}
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsUnspecified(),
addr.IsMulticast():
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, initialNextHop Nexthop) (Nexthop, error) {
if err := r.validateRoute(prefix); err != nil {
return Nexthop{}, err
}
addr := prefix.Addr()
if addr.IsUnspecified() {
return Nexthop{}, vars.ErrRouteNotAllowed
}
@@ -179,10 +134,7 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.WGIface
Intf: nexthop.Intf,
}
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
vpnAddr := vpnIntf.Address().IP
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
@@ -271,32 +223,7 @@ func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) er
return nil
}
return r.addNonExistingRoute(prefix, intf)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
return r.addToRouteTable(prefix, nextHop)
}
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
@@ -408,12 +335,8 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) {
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if runtime.GOOS == "freebsd" {
return Nexthop{Intf: intf}, nil
}
if preferredSrc == nil {
return Nexthop{}, vars.ErrRouteNotFound
return Nexthop{Intf: intf}, nil
}
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
@@ -457,32 +380,6 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
return addr.Unmap(), nil
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := GetRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := GetRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
localRoutes, err := hasSeparateRouting()

View File

@@ -3,23 +3,25 @@
package systemops
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"syscall"
"testing"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
type dialer interface {
@@ -27,105 +29,370 @@ type dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
func TestAddRemoveRoutes(t *testing.T) {
func TestAddVPNRoute(t *testing.T) {
testCases := []struct {
name string
prefix netip.Prefix
shouldRouteToWireguard bool
shouldBeRemoved bool
name string
prefix netip.Prefix
expectError bool
}{
{
name: "Should Add And Remove Route 100.66.120.0/24",
prefix: netip.MustParsePrefix("100.66.120.0/24"),
shouldRouteToWireguard: true,
shouldBeRemoved: true,
name: "IPv4 - Private network route",
prefix: netip.MustParsePrefix("10.10.100.0/24"),
},
{
name: "Should Not Add Or Remove Route 127.0.0.1/32",
prefix: netip.MustParsePrefix("127.0.0.1/32"),
shouldRouteToWireguard: false,
shouldBeRemoved: false,
name: "IPv4 Single host",
prefix: netip.MustParsePrefix("10.111.111.111/32"),
},
{
name: "IPv4 RFC3927 test range",
prefix: netip.MustParsePrefix("198.51.100.0/24"),
},
{
name: "IPv4 Default route",
prefix: netip.MustParsePrefix("0.0.0.0/0"),
},
{
name: "IPv6 Subnet",
prefix: netip.MustParsePrefix("fdb1:848a:7e16::/48"),
},
{
name: "IPv6 Single host",
prefix: netip.MustParsePrefix("fdb1:848a:7e16:a::b/128"),
},
{
name: "IPv6 Default route",
prefix: netip.MustParsePrefix("::/0"),
},
// IPv4 addresses that should be rejected (matches validateRoute logic)
{
name: "IPv4 Loopback",
prefix: netip.MustParsePrefix("127.0.0.1/32"),
expectError: true,
},
{
name: "IPv4 Link-local unicast",
prefix: netip.MustParsePrefix("169.254.1.1/32"),
expectError: true,
},
{
name: "IPv4 Link-local multicast",
prefix: netip.MustParsePrefix("224.0.0.251/32"),
expectError: true,
},
{
name: "IPv4 Multicast",
prefix: netip.MustParsePrefix("239.255.255.250/32"),
expectError: true,
},
{
name: "IPv4 Unspecified with prefix",
prefix: netip.MustParsePrefix("0.0.0.0/32"),
expectError: true,
},
// IPv6 addresses that should be rejected (matches validateRoute logic)
{
name: "IPv6 Loopback",
prefix: netip.MustParsePrefix("::1/128"),
expectError: true,
},
{
name: "IPv6 Link-local unicast",
prefix: netip.MustParsePrefix("fe80::1/128"),
expectError: true,
},
{
name: "IPv6 Link-local multicast",
prefix: netip.MustParsePrefix("ff02::1/128"),
expectError: true,
},
{
name: "IPv6 Interface-local multicast",
prefix: netip.MustParsePrefix("ff01::1/128"),
expectError: true,
},
{
name: "IPv6 Multicast",
prefix: netip.MustParsePrefix("ff00::1/128"),
expectError: true,
},
{
name: "IPv6 Unspecified with prefix",
prefix: netip.MustParsePrefix("::/128"),
expectError: true,
},
{
name: "IPv4 WireGuard interface network overlap",
prefix: netip.MustParsePrefix("100.65.75.0/24"),
expectError: true,
},
{
name: "IPv4 WireGuard interface network subnet",
prefix: netip.MustParsePrefix("100.65.75.0/32"),
expectError: true,
},
}
for n, testCase := range testCases {
// todo resolve test execution on freebsd
if runtime.GOOS == "freebsd" {
t.Skip("skipping ", testCase.name, " on freebsd")
}
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun53%d", n),
Address: "100.65.75.2/24",
WGPrivKey: peerPrivateKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgInterface, err := iface.NewWGIFace(opts)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil)
_, _, err = r.SetupRouting(nil, nil)
_, _, err := r.SetupRouting(nil, nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))
})
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
intf, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err)
// add the route
err = r.AddVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericAddVPNRoute should not return err")
if testCase.expectError {
assert.ErrorIs(t, err, vars.ErrRouteNotAllowed)
return
}
if testCase.shouldRouteToWireguard {
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
// validate it's pointing to the WireGuard interface
require.NoError(t, err)
nextHop := getNextHop(t, testCase.prefix.Addr())
assert.Equal(t, wgInterface.Name(), nextHop.Intf.Name, "next hop interface should be WireGuard interface")
// remove route again
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err)
// validate it's gone
nextHop, err = GetNextHop(testCase.prefix.Addr())
require.True(t,
errors.Is(err, vars.ErrRouteNotFound) || err == nil && nextHop.Intf != nil && nextHop.Intf.Name != wgInterface.Name(),
"err: %v, next hop: %v", err, nextHop)
})
}
}
func getNextHop(t *testing.T, addr netip.Addr) Nexthop {
t.Helper()
if runtime.GOOS == "windows" || runtime.GOOS == "linux" {
nextHop, err := GetNextHop(addr)
if runtime.GOOS == "windows" && errors.Is(err, vars.ErrRouteNotFound) && addr.Is6() {
// TODO: Fix this test. It doesn't return the route when running in a windows github runner, but it is
// present in the route table.
t.Skip("Skipping windows test")
}
require.NoError(t, err)
require.NotNil(t, nextHop.Intf, "next hop interface should not be nil for %s", addr)
return nextHop
}
// GetNextHop for bsd is buggy and returns the wrong interface for the default route.
if addr.IsUnspecified() {
// On macOS, querying 0.0.0.0 returns the wrong interface
if addr.Is4() {
addr = netip.MustParseAddr("1.2.3.4")
} else {
addr = netip.MustParseAddr("2001:db8::1")
}
}
cmd := exec.Command("route", "-n", "get", addr.String())
if addr.Is6() {
cmd = exec.Command("route", "-n", "get", "-inet6", addr.String())
}
output, err := cmd.CombinedOutput()
t.Logf("route output: %s", output)
require.NoError(t, err, "%s failed")
lines := strings.Split(string(output), "\n")
var intf string
var gateway string
for _, line := range lines {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "interface:") {
intf = strings.TrimSpace(strings.TrimPrefix(line, "interface:"))
} else if strings.HasPrefix(line, "gateway:") {
gateway = strings.TrimSpace(strings.TrimPrefix(line, "gateway:"))
}
}
require.NotEmpty(t, intf, "interface should be found in route output")
iface, err := net.InterfaceByName(intf)
require.NoError(t, err, "interface %s should exist", intf)
nexthop := Nexthop{Intf: iface}
if gateway != "" && gateway != "link#"+strconv.Itoa(iface.Index) {
addr, err := netip.ParseAddr(gateway)
if err == nil {
nexthop.IP = addr
}
}
return nexthop
}
func TestAddRouteToNonVPNIntf(t *testing.T) {
testCases := []struct {
name string
prefix netip.Prefix
expectError bool
errorType error
}{
{
name: "IPv4 RFC3927 test range",
prefix: netip.MustParsePrefix("198.51.100.0/24"),
},
{
name: "IPv4 Single host",
prefix: netip.MustParsePrefix("8.8.8.8/32"),
},
{
name: "IPv6 External network route",
prefix: netip.MustParsePrefix("2001:db8:1000::/48"),
},
{
name: "IPv6 Single host",
prefix: netip.MustParsePrefix("2001:db8::1/128"),
},
{
name: "IPv6 Subnet",
prefix: netip.MustParsePrefix("2a05:d014:1f8d::/48"),
},
{
name: "IPv6 Single host",
prefix: netip.MustParsePrefix("2a05:d014:1f8d:7302:ebca:ec15:b24d:d07e/128"),
},
// Addresses that should be rejected
{
name: "IPv4 Loopback",
prefix: netip.MustParsePrefix("127.0.0.1/32"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 Link-local unicast",
prefix: netip.MustParsePrefix("169.254.1.1/32"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 Multicast",
prefix: netip.MustParsePrefix("239.255.255.250/32"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 Unspecified",
prefix: netip.MustParsePrefix("0.0.0.0/0"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Loopback",
prefix: netip.MustParsePrefix("::1/128"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Link-local unicast",
prefix: netip.MustParsePrefix("fe80::1/128"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Multicast",
prefix: netip.MustParsePrefix("ff00::1/128"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv6 Unspecified",
prefix: netip.MustParsePrefix("::/0"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
{
name: "IPv4 WireGuard interface network overlap",
prefix: netip.MustParsePrefix("100.65.75.0/24"),
expectError: true,
errorType: vars.ErrRouteNotAllowed,
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))
})
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
require.NoError(t, err, "Should be able to get IPv4 default route")
t.Logf("Initial IPv4 next hop: %s", initialNextHopV4)
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
if testCase.prefix.Addr().Is6() &&
(errors.Is(err, vars.ErrRouteNotFound) || initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun")) {
t.Skip("Skipping test as no ipv6 default route is available")
}
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
t.Fatalf("Failed to get IPv6 default route: %v", err)
}
var initialNextHop Nexthop
if testCase.prefix.Addr().Is6() {
initialNextHop = initialNextHopV6
} else {
assertWGOutInterface(t, testCase.prefix, wgInterface, true)
initialNextHop = initialNextHopV4
}
exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard {
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
prefixNexthop, err := GetNextHop(testCase.prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err")
nexthop, err := r.addRouteToNonVPNIntf(testCase.prefix, wgInterface, initialNextHop)
internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
require.NoError(t, err)
if testCase.shouldBeRemoved {
require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway")
} else {
require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway")
}
if testCase.expectError {
require.ErrorIs(t, err, vars.ErrRouteNotAllowed)
return
}
require.NoError(t, err)
t.Logf("Next hop for %s: %s", testCase.prefix, nexthop)
// Verify the route was added and points to non-VPN interface
currentNextHop, err := GetNextHop(testCase.prefix.Addr())
require.NoError(t, err)
assert.NotEqual(t, wgInterface.Name(), currentNextHop.Intf.Name, "Route should not point to VPN interface")
err = r.removeFromRouteTable(testCase.prefix, nexthop)
assert.NoError(t, err)
})
}
}
func TestGetNextHop(t *testing.T) {
if runtime.GOOS == "freebsd" {
t.Skip("skipping on freebsd")
}
nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
defaultNh, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
if !nexthop.IP.IsValid() {
if !defaultNh.IP.IsValid() {
t.Fatal("should return a gateway")
}
addresses, err := net.InterfaceAddrs()
@@ -133,7 +400,6 @@ func TestGetNextHop(t *testing.T) {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var testingIP string
var testingPrefix netip.Prefix
for _, address := range addresses {
if address.Network() != "ip+net" {
@@ -141,213 +407,23 @@ func TestGetNextHop(t *testing.T) {
}
prefix := netip.MustParsePrefix(address.String())
if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() {
testingIP = prefix.Addr().String()
testingPrefix = prefix.Masked()
break
}
}
localIP, err := GetNextHop(testingPrefix.Addr())
nh, err := GetNextHop(testingPrefix.Addr())
if err != nil {
t.Fatal("shouldn't return error: ", err)
}
if !localIP.IP.IsValid() {
if nh.Intf == nil {
t.Fatal("should return a gateway for local network")
}
if localIP.IP.String() == nexthop.IP.String() {
t.Fatal("local IP should not match with gateway IP")
if nh.IP.String() == defaultNh.IP.String() {
t.Fatal("next hop IP should not match with default gateway IP")
}
if localIP.IP.String() != testingIP {
t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String())
}
}
func TestAddExistAndRemoveRoute(t *testing.T) {
defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
t.Log("defaultNexthop: ", defaultNexthop)
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
testCases := []struct {
name string
prefix netip.Prefix
preExistingPrefix netip.Prefix
shouldAddRoute bool
}{
{
name: "Should Add And Remove random Route",
prefix: netip.MustParsePrefix("99.99.99.99/32"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if overlaps with default gateway",
prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"),
shouldAddRoute: false,
},
{
name: "Should Add Route if bigger network exists",
prefix: netip.MustParsePrefix("100.100.100.0/24"),
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: true,
},
{
name: "Should Add Route if smaller network exists",
prefix: netip.MustParsePrefix("100.100.0.0/16"),
preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if same network exists",
prefix: netip.MustParsePrefix("100.100.0.0/16"),
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: false,
},
}
for n, testCase := range testCases {
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun53%d", n),
Address: "100.65.75.2/24",
WGPort: 33100,
WGPrivKey: peerPrivateKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgInterface, err := iface.NewWGIFace(opts)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
r := NewSysOps(wgInterface, nil)
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
err := r.AddVPNRoute(testCase.preExistingPrefix, intf)
require.NoError(t, err, "should not return err when adding pre-existing route")
}
// Add the route
err = r.AddVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute {
// test if route exists after adding
ok, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "should not return err")
require.True(t, ok, "route should exist")
// remove route again if added
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err")
}
// route should either not have been added or should have been removed
// In case of already existing route, it should not have been added (but still exist)
ok, err := existsInRouteTable(testCase.prefix)
t.Log("Buffer string: ", buf.String())
require.NoError(t, err, "should not return err")
if !strings.Contains(buf.String(), "because it already exists") {
require.False(t, ok, "route should not exist")
}
})
}
}
func TestIsSubRange(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var subRangeAddressPrefixes []netip.Prefix
var nonSubRangeAddressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
}
}
for _, prefix := range subRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if !isSubRangePrefix {
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
}
}
for _, prefix := range nonSubRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if isSubRangePrefix {
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
}
}
}
func TestExistsInRouteTable(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var addressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
switch {
case p.Addr().Is6():
continue
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
case runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast():
continue
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
case runtime.GOOS == "linux" && p.Addr().IsLoopback():
continue
// FreeBSD loopback 127/8 is not added to the routing table
case runtime.GOOS == "freebsd" && p.Addr().IsLoopback():
continue
default:
addressPrefixes = append(addressPrefixes, p.Masked())
}
}
for _, prefix := range addressPrefixes {
exists, err := existsInRouteTable(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
}
if !exists {
t.Fatalf("address %s should exist in route table", prefix)
}
if nh.Intf.Name != defaultNh.Intf.Name {
t.Fatalf("next hop interface name should match with default gateway interface name, got: %s, want: %s", nh.Intf.Name, defaultNh.Intf.Name)
}
}
@@ -384,11 +460,16 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
t.Helper()
err := r.AddVPNRoute(prefix, intf)
require.NoError(t, err, "addVPNRoute should not return err")
if err := r.AddVPNRoute(prefix, intf); err != nil {
if !errors.Is(err, syscall.EEXIST) && !errors.Is(err, vars.ErrRouteNotAllowed) {
t.Fatalf("addVPNRoute should not return err: %v", err)
}
t.Logf("addVPNRoute %v returned: %v", prefix, err)
}
t.Cleanup(func() {
err = r.RemoveVPNRoute(prefix, intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
if err := r.RemoveVPNRoute(prefix, intf); err != nil && !errors.Is(err, vars.ErrRouteNotAllowed) {
t.Fatalf("removeVPNRoute should not return err: %v", err)
}
})
}
@@ -422,28 +503,10 @@ func setupTestEnv(t *testing.T) {
// 10.10.0.0/24 more specific route exists in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf)
// 127.0.10.0/24 more specific route exists in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf)
// unique route in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
}
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
t.Helper()
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
return
}
prefixNexthop, err := GetNextHop(prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err")
if invert {
assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP")
} else {
assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP")
}
}
func TestIsVpnRoute(t *testing.T) {
tests := []struct {
name string

View File

@@ -149,6 +149,10 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
}
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
if !nbnet.AdvancedRouting() {
return r.genericAddVPNRoute(prefix, intf)
}
@@ -172,6 +176,10 @@ func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
if !nbnet.AdvancedRouting() {
return r.genericRemoveVPNRoute(prefix, intf)
}
@@ -219,7 +227,7 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) {
ones, _ := route.Dst.Mask.Size()
prefix := netip.PrefixFrom(addr, ones)
prefix := netip.PrefixFrom(addr.Unmap(), ones)
if prefix.IsValid() {
prefixList = append(prefixList, prefix)
}
@@ -247,7 +255,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
return fmt.Errorf("add gateway and device: %w", err)
}
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) {
return fmt.Errorf("netlink add route: %w", err)
}
@@ -270,7 +278,7 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
Dst: ipNet,
}
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
if err := netlink.RouteAdd(route); err != nil && !isOpErr(err) {
return fmt.Errorf("netlink add unreachable route: %w", err)
}

View File

@@ -19,7 +19,6 @@ import (
)
var expectedVPNint = "wgtest0"
var expectedLoopbackInt = "lo"
var expectedExternalInt = "dummyext0"
var expectedInternalInt = "dummyint0"
@@ -31,12 +30,6 @@ func init() {
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53),
},
{
name: "To more specific route (local) without custom dialer via physical interface",
expectedInterface: expectedLoopbackInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
},
}...)
}

View File

@@ -11,10 +11,16 @@ import (
)
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
return r.genericAddVPNRoute(prefix, intf)
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if err := r.validateRoute(prefix); err != nil {
return err
}
return r.genericRemoveVPNRoute(prefix, intf)
}

View File

@@ -0,0 +1,268 @@
package systemops
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
type mockWGIface struct {
address wgaddr.Address
name string
}
func (m *mockWGIface) Address() wgaddr.Address {
return m.address
}
func (m *mockWGIface) Name() string {
return m.name
}
func TestSysOps_validateRoute(t *testing.T) {
wgNetwork := netip.MustParsePrefix("10.0.0.0/24")
mockWG := &mockWGIface{
address: wgaddr.Address{
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "wg0",
}
sysOps := &SysOps{
wgInterface: mockWG,
notifier: &notifier.Notifier{},
}
tests := []struct {
name string
prefix string
expectError bool
}{
// Valid routes
{
name: "valid IPv4 route",
prefix: "192.168.1.0/24",
expectError: false,
},
{
name: "valid IPv6 route",
prefix: "2001:db8::/32",
expectError: false,
},
{
name: "valid single IPv4 host",
prefix: "8.8.8.8/32",
expectError: false,
},
{
name: "valid single IPv6 host",
prefix: "2001:4860:4860::8888/128",
expectError: false,
},
// Invalid routes - loopback
{
name: "IPv4 loopback",
prefix: "127.0.0.1/32",
expectError: true,
},
{
name: "IPv6 loopback",
prefix: "::1/128",
expectError: true,
},
// Invalid routes - link-local unicast
{
name: "IPv4 link-local unicast",
prefix: "169.254.1.1/32",
expectError: true,
},
{
name: "IPv6 link-local unicast",
prefix: "fe80::1/128",
expectError: true,
},
// Invalid routes - multicast
{
name: "IPv4 multicast",
prefix: "224.0.0.1/32",
expectError: true,
},
{
name: "IPv6 multicast",
prefix: "ff02::1/128",
expectError: true,
},
// Invalid routes - link-local multicast
{
name: "IPv4 link-local multicast",
prefix: "224.0.0.0/24",
expectError: true,
},
{
name: "IPv6 link-local multicast",
prefix: "ff02::/16",
expectError: true,
},
// Invalid routes - interface-local multicast (IPv6 only)
{
name: "IPv6 interface-local multicast",
prefix: "ff01::1/128",
expectError: true,
},
// Invalid routes - overlaps with WG interface network
{
name: "overlaps with WG network - exact match",
prefix: "10.0.0.0/24",
expectError: true,
},
{
name: "overlaps with WG network - subset",
prefix: "10.0.0.1/32",
expectError: true,
},
{
name: "overlaps with WG network - host in range",
prefix: "10.0.0.100/32",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prefix, err := netip.ParsePrefix(tt.prefix)
require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix)
err = sysOps.validateRoute(prefix)
if tt.expectError {
require.Error(t, err, "validateRoute() expected error for %s", tt.prefix)
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s", tt.prefix)
} else {
assert.NoError(t, err, "validateRoute() expected no error for %s", tt.prefix)
}
})
}
}
func TestSysOps_validateRoute_SubnetOverlap(t *testing.T) {
wgNetwork := netip.MustParsePrefix("192.168.100.0/24")
mockWG := &mockWGIface{
address: wgaddr.Address{
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "wg0",
}
sysOps := &SysOps{
wgInterface: mockWG,
notifier: &notifier.Notifier{},
}
tests := []struct {
name string
prefix string
expectError bool
description string
}{
{
name: "identical subnet",
prefix: "192.168.100.0/24",
expectError: true,
description: "exact same network as WG interface",
},
{
name: "broader subnet containing WG network",
prefix: "192.168.0.0/16",
expectError: false,
description: "broader network that contains WG network should be allowed",
},
{
name: "host within WG network",
prefix: "192.168.100.50/32",
expectError: true,
description: "specific host within WG network",
},
{
name: "subnet within WG network",
prefix: "192.168.100.128/25",
expectError: true,
description: "smaller subnet within WG network",
},
{
name: "adjacent subnet - same /23",
prefix: "192.168.101.0/24",
expectError: false,
description: "adjacent subnet, no overlap",
},
{
name: "adjacent subnet - different /16",
prefix: "192.167.100.0/24",
expectError: false,
description: "different network, no overlap",
},
{
name: "WG network broadcast address",
prefix: "192.168.100.255/32",
expectError: true,
description: "broadcast address of WG network",
},
{
name: "WG network first usable",
prefix: "192.168.100.1/32",
expectError: true,
description: "first usable address in WG network",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prefix, err := netip.ParsePrefix(tt.prefix)
require.NoError(t, err, "Failed to parse test prefix %s", tt.prefix)
err = sysOps.validateRoute(prefix)
if tt.expectError {
require.Error(t, err, "validateRoute() expected error for %s (%s)", tt.prefix, tt.description)
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for %s (%s)", tt.prefix, tt.description)
} else {
assert.NoError(t, err, "validateRoute() expected no error for %s (%s)", tt.prefix, tt.description)
}
})
}
}
func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
wgNetwork := netip.MustParsePrefix("10.0.0.0/24")
mockWG := &mockWGIface{
address: wgaddr.Address{
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "wt0",
}
sysOps := &SysOps{
wgInterface: mockWG,
notifier: &notifier.Notifier{},
}
var invalidPrefix netip.Prefix
err := sysOps.validateRoute(invalidPrefix)
require.Error(t, err, "validateRoute() expected error for invalid prefix")
assert.Equal(t, vars.ErrRouteNotAllowed, err, "validateRoute() expected ErrRouteNotAllowed for invalid prefix")
}

View File

@@ -3,15 +3,19 @@
package systemops
import (
"errors"
"fmt"
"net"
"net/netip"
"os/exec"
"strings"
"strconv"
"syscall"
"time"
"unsafe"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
@@ -26,48 +30,16 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeCmd("add", prefix, nexthop)
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeCmd("delete", prefix, nexthop)
return r.routeSocket(unix.RTM_DELETE, prefix, nexthop)
}
func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error {
inet := "-inet"
if prefix.Addr().Is6() {
inet = "-inet6"
}
network := prefix.String()
if prefix.IsSingleIP() {
network = prefix.Addr().String()
}
args := []string{"-n", action, inet, network}
if nexthop.IP.IsValid() {
args = append(args, nexthop.IP.Unmap().String())
} else if nexthop.Intf != nil {
args = append(args, "-interface", nexthop.Intf.Name)
}
if err := retryRouteCmd(args); err != nil {
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
}
return nil
}
func retryRouteCmd(args []string) error {
operation := func() error {
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
// https://github.com/golang/go/issues/45736
if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") {
return err
} else if err != nil {
return backoff.Permanent(err)
}
return nil
func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) error {
if !prefix.IsValid() {
return fmt.Errorf("invalid prefix: %s", prefix)
}
expBackOff := backoff.NewExponentialBackOff()
@@ -75,9 +47,157 @@ func retryRouteCmd(args []string) error {
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = 1 * time.Second
err := backoff.Retry(operation, expBackOff)
if err != nil {
return fmt.Errorf("route cmd retry failed: %w", err)
if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
a := "add"
if action == unix.RTM_DELETE {
a = "remove"
}
return fmt.Errorf("%s route for %s: %w", a, prefix, err)
}
return nil
}
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
operation := func() error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %w", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("failed to close routing socket: %v", err)
}
}()
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
}
msgBytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
}
if _, err = unix.Write(fd, msgBytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
}
respBuf := make([]byte, 2048)
n, err := unix.Read(fd, respBuf)
if err != nil {
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
}
if n > 0 {
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
return backoff.Permanent(err)
}
}
return nil
}
return operation
}
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
msg = &route.RouteMessage{
Type: action,
Flags: unix.RTF_UP,
Version: unix.RTM_VERSION,
Seq: 1,
}
const numAddrs = unix.RTAX_NETMASK + 1
addrs := make([]route.Addr, numAddrs)
addrs[unix.RTAX_DST], err = addrToRouteAddr(prefix.Addr())
if err != nil {
return nil, fmt.Errorf("build destination address for %s: %w", prefix.Addr(), err)
}
if prefix.IsSingleIP() {
msg.Flags |= unix.RTF_HOST
} else {
addrs[unix.RTAX_NETMASK], err = prefixToRouteNetmask(prefix)
if err != nil {
return nil, fmt.Errorf("build netmask for %s: %w", prefix, err)
}
}
if nexthop.IP.IsValid() {
msg.Flags |= unix.RTF_GATEWAY
addrs[unix.RTAX_GATEWAY], err = addrToRouteAddr(nexthop.IP.Unmap())
if err != nil {
return nil, fmt.Errorf("build gateway IP address for %s: %w", nexthop.IP, err)
}
} else if nexthop.Intf != nil {
msg.Index = nexthop.Intf.Index
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
Index: nexthop.Intf.Index,
Name: nexthop.Intf.Name,
}
}
msg.Addrs = addrs
return msg, nil
}
func (r *SysOps) parseRouteResponse(buf []byte) error {
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
return nil
}
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
if rtMsg.Errno != 0 {
return fmt.Errorf("parse: %d", rtMsg.Errno)
}
return nil
}
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
if addr.Is4() {
return &route.Inet4Addr{IP: addr.As4()}, nil
}
if addr.Zone() == "" {
return &route.Inet6Addr{IP: addr.As16()}, nil
}
var zone int
// zone can be either a numeric zone ID or an interface name.
if z, err := strconv.Atoi(addr.Zone()); err == nil {
zone = z
} else {
iface, err := net.InterfaceByName(addr.Zone())
if err != nil {
return nil, fmt.Errorf("resolve zone '%s': %w", addr.Zone(), err)
}
zone = iface.Index
}
return &route.Inet6Addr{IP: addr.As16(), ZoneID: zone}, nil
}
func prefixToRouteNetmask(prefix netip.Prefix) (route.Addr, error) {
bits := prefix.Bits()
if prefix.Addr().Is4() {
m := net.CIDRMask(bits, 32)
var maskBytes [4]byte
copy(maskBytes[:], m)
return &route.Inet4Addr{IP: maskBytes}, nil
}
if prefix.Addr().Is6() {
m := net.CIDRMask(bits, 128)
var maskBytes [16]byte
copy(maskBytes[:], m)
return &route.Inet6Addr{IP: maskBytes}, nil
}
return nil, fmt.Errorf("unknown IP version in prefix: %s", prefix.Addr().String())
}

View File

@@ -1,5 +1,3 @@
//go:build windows
package systemops
import (
@@ -9,9 +7,8 @@ import (
"net"
"net/netip"
"os"
"os/exec"
"runtime/debug"
"strconv"
"strings"
"sync"
"syscall"
"time"
@@ -21,11 +18,12 @@ import (
"github.com/yusufpapurcu/wmi"
"golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
)
const InfiniteLifetime = 0xffffffff
type RouteUpdateType int
// RouteUpdate represents a change in the routing table.
@@ -58,9 +56,13 @@ type MSFT_NetRoute struct {
AddressFamily uint16
}
// MIB_IPFORWARD_ROW2 is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipforward_row2
// luid represents a locally unique identifier for network interfaces
type luid uint64
// MIB_IPFORWARD_ROW2 represents a route entry in the routing table.
// It is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipforward_row2
type MIB_IPFORWARD_ROW2 struct {
InterfaceLuid uint64
InterfaceLuid luid
InterfaceIndex uint32
DestinationPrefix IP_ADDRESS_PREFIX
NextHop SOCKADDR_INET_NEXTHOP
@@ -108,9 +110,14 @@ type SOCKADDR_INET_NEXTHOP struct {
type MIB_NOTIFICATION_TYPE int32
var (
modiphlpapi = windows.NewLazyDLL("iphlpapi.dll")
procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2")
procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2")
modiphlpapi = windows.NewLazyDLL("iphlpapi.dll")
procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2")
procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2")
procCreateIpForwardEntry2 = modiphlpapi.NewProc("CreateIpForwardEntry2")
procDeleteIpForwardEntry2 = modiphlpapi.NewProc("DeleteIpForwardEntry2")
procGetIpForwardEntry2 = modiphlpapi.NewProc("GetIpForwardEntry2")
procInitializeIpForwardEntry = modiphlpapi.NewProc("InitializeIpForwardEntry")
procConvertInterfaceIndexToLuid = modiphlpapi.NewProc("ConvertInterfaceIndexToLuid")
prefixList []netip.Prefix
lastUpdate time.Time
@@ -139,6 +146,8 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
log.Debugf("Adding route to %s via %s", prefix, nexthop)
// if we don't have an interface but a zone, extract the interface index from the zone
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
zone, err := strconv.Atoi(nexthop.IP.Zone())
if err != nil {
@@ -147,23 +156,187 @@ func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
nexthop.Intf = &net.Interface{Index: zone}
}
return addRouteCmd(prefix, nexthop)
return addRoute(prefix, nexthop)
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"delete", prefix.String()}
if nexthop.IP.IsValid() {
ip := nexthop.IP.WithZone("")
args = append(args, ip.Unmap().String())
log.Debugf("Removing route to %s via %s", prefix, nexthop)
return deleteRoute(prefix, nexthop)
}
// setupRouteEntry prepares a route entry with common configuration
func setupRouteEntry(prefix netip.Prefix, nexthop Nexthop) (*MIB_IPFORWARD_ROW2, error) {
route := &MIB_IPFORWARD_ROW2{}
initializeIPForwardEntry(route)
// Convert interface index to luid if interface is specified
if nexthop.Intf != nil {
var luid luid
if err := convertInterfaceIndexToLUID(uint32(nexthop.Intf.Index), &luid); err != nil {
return nil, fmt.Errorf("convert interface index to luid: %w", err)
}
route.InterfaceLuid = luid
route.InterfaceIndex = uint32(nexthop.Intf.Index)
}
routeCmd := uspfilter.GetSystem32Command("route")
if err := setDestinationPrefix(&route.DestinationPrefix, prefix); err != nil {
return nil, fmt.Errorf("set destination prefix: %w", err)
}
out, err := exec.Command(routeCmd, args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if nexthop.IP.IsValid() {
if err := setNextHop(&route.NextHop, nexthop.IP); err != nil {
return nil, fmt.Errorf("set next hop: %w", err)
}
}
if err != nil {
return fmt.Errorf("remove route: %w", err)
return route, nil
}
// addRoute adds a route using Windows iphelper APIs
func addRoute(prefix netip.Prefix, nexthop Nexthop) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in addRoute: %v, stack trace: %s", r, debug.Stack())
}
}()
route, setupErr := setupRouteEntry(prefix, nexthop)
if setupErr != nil {
return fmt.Errorf("setup route entry: %w", setupErr)
}
route.Metric = 1
route.ValidLifetime = InfiniteLifetime
route.PreferredLifetime = InfiniteLifetime
return createIPForwardEntry2(route)
}
// deleteRoute deletes a route using Windows iphelper APIs
func deleteRoute(prefix netip.Prefix, nexthop Nexthop) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in deleteRoute: %v, stack trace: %s", r, debug.Stack())
}
}()
route, setupErr := setupRouteEntry(prefix, nexthop)
if setupErr != nil {
return fmt.Errorf("setup route entry: %w", setupErr)
}
if err := getIPForwardEntry2(route); err != nil {
return fmt.Errorf("get route entry: %w", err)
}
return deleteIPForwardEntry2(route)
}
// setDestinationPrefix sets the destination prefix in the route structure
func setDestinationPrefix(prefix *IP_ADDRESS_PREFIX, dest netip.Prefix) error {
addr := dest.Addr()
prefix.PrefixLength = uint8(dest.Bits())
if addr.Is4() {
prefix.Prefix.sin6_family = windows.AF_INET
ip4 := addr.As4()
binary.BigEndian.PutUint32(prefix.Prefix.data[:4],
uint32(ip4[0])<<24|uint32(ip4[1])<<16|uint32(ip4[2])<<8|uint32(ip4[3]))
return nil
}
if addr.Is6() {
prefix.Prefix.sin6_family = windows.AF_INET6
ip6 := addr.As16()
copy(prefix.Prefix.data[4:20], ip6[:])
if zone := addr.Zone(); zone != "" {
if scopeID, err := strconv.ParseUint(zone, 10, 32); err == nil {
binary.BigEndian.PutUint32(prefix.Prefix.data[20:24], uint32(scopeID))
}
}
return nil
}
return fmt.Errorf("invalid address family")
}
// setNextHop sets the next hop address in the route structure
func setNextHop(nextHop *SOCKADDR_INET_NEXTHOP, addr netip.Addr) error {
if addr.Is4() {
nextHop.sin6_family = windows.AF_INET
ip4 := addr.As4()
binary.BigEndian.PutUint32(nextHop.data[:4],
uint32(ip4[0])<<24|uint32(ip4[1])<<16|uint32(ip4[2])<<8|uint32(ip4[3]))
return nil
}
if addr.Is6() {
nextHop.sin6_family = windows.AF_INET6
ip6 := addr.As16()
copy(nextHop.data[4:20], ip6[:])
// Handle zone if present
if zone := addr.Zone(); zone != "" {
if scopeID, err := strconv.ParseUint(zone, 10, 32); err == nil {
binary.BigEndian.PutUint32(nextHop.data[20:24], uint32(scopeID))
}
}
return nil
}
return fmt.Errorf("invalid address family")
}
// Windows API wrappers
func createIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
r1, _, e1 := syscall.SyscallN(procCreateIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route)))
if r1 != 0 {
if e1 != 0 {
return fmt.Errorf("CreateIpForwardEntry2: %w", e1)
}
return fmt.Errorf("CreateIpForwardEntry2: code %d", r1)
}
return nil
}
func deleteIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
r1, _, e1 := syscall.SyscallN(procDeleteIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route)))
if r1 != 0 {
if e1 != 0 {
return fmt.Errorf("DeleteIpForwardEntry2: %w", e1)
}
return fmt.Errorf("DeleteIpForwardEntry2: code %d", r1)
}
return nil
}
func getIPForwardEntry2(route *MIB_IPFORWARD_ROW2) error {
r1, _, e1 := syscall.SyscallN(procGetIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route)))
if r1 != 0 {
if e1 != 0 {
return fmt.Errorf("GetIpForwardEntry2: %w", e1)
}
return fmt.Errorf("GetIpForwardEntry2: code %d", r1)
}
return nil
}
// https://learn.microsoft.com/en-us/windows/win32/api/netioapi/nf-netioapi-initializeipforwardentry
func initializeIPForwardEntry(route *MIB_IPFORWARD_ROW2) {
// Does not return anything. Trying to handle the error might return an uninitialized value.
_, _, _ = syscall.SyscallN(procInitializeIpForwardEntry.Addr(), uintptr(unsafe.Pointer(route)))
}
func convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *luid) error {
r1, _, e1 := syscall.SyscallN(procConvertInterfaceIndexToLuid.Addr(),
uintptr(interfaceIndex), uintptr(unsafe.Pointer(interfaceLUID)))
if r1 != 0 {
if e1 != 0 {
return fmt.Errorf("ConvertInterfaceIndexToLuid: %w", e1)
}
return fmt.Errorf("ConvertInterfaceIndexToLuid: code %d", r1)
}
return nil
}
@@ -319,7 +492,7 @@ func cancelMibChangeNotify2(handle windows.Handle) error {
}
// GetRoutesFromTable returns the current routing table from with prefixes only.
// It ccaches the result for 2 seconds to avoid blocking the caller.
// It caches the result for 2 seconds to avoid blocking the caller.
func GetRoutesFromTable() ([]netip.Prefix, error) {
mux.Lock()
defer mux.Unlock()
@@ -388,35 +561,6 @@ func GetRoutes() ([]Route, error) {
return routes, nil
}
func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"add", prefix.String()}
if nexthop.IP.IsValid() {
ip := nexthop.IP.WithZone("")
args = append(args, ip.Unmap().String())
} else {
addr := "0.0.0.0"
if prefix.Addr().Is6() {
addr = "::"
}
args = append(args, addr)
}
if nexthop.Intf != nil {
args = append(args, "if", strconv.Itoa(nexthop.Intf.Index))
}
routeCmd := uspfilter.GetSystem32Command("route")
out, err := exec.Command(routeCmd, args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("route add: %w", err)
}
return nil
}
func isCacheDisabled() bool {
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
}

View File

@@ -5,18 +5,23 @@ import (
"encoding/json"
"fmt"
"net"
"net/netip"
"os/exec"
"strings"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/util/net"
)
var expectedExtInt = "Ethernet1"
var (
expectedExternalInt = "Ethernet1"
expectedVPNint = "wgtest0"
)
type RouteInfo struct {
NextHop string `json:"nexthop"`
@@ -43,8 +48,6 @@ type testCase struct {
dialer dialer
}
var expectedVPNint = "wgtest0"
var testCases = []testCase{
{
name: "To external host without custom dialer via vpn",
@@ -52,14 +55,14 @@ var testCases = []testCase{
expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "128.0.0.0/1",
expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
},
{
name: "To external host with custom dialer via physical interface",
destination: "192.0.2.1:53",
expectedDestPrefix: "192.0.2.1/32",
expectedInterface: expectedExtInt,
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
},
@@ -67,24 +70,15 @@ var testCases = []testCase{
name: "To duplicate internal route with custom dialer via physical interface",
destination: "10.0.0.2:53",
expectedDestPrefix: "10.0.0.2/32",
expectedInterface: expectedExtInt,
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
},
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
destination: "10.0.0.2:53",
expectedSourceIP: "127.0.0.1",
expectedDestPrefix: "10.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
dialer: &net.Dialer{},
},
{
name: "To unique vpn route with custom dialer via physical interface",
destination: "172.16.0.2:53",
expectedDestPrefix: "172.16.0.2/32",
expectedInterface: expectedExtInt,
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
},
{
@@ -93,7 +87,7 @@ var testCases = []testCase{
expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "172.16.0.0/12",
expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
},
@@ -103,22 +97,14 @@ var testCases = []testCase{
expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "10.10.0.0/24",
expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0",
dialer: &net.Dialer{},
},
{
name: "To more specific route (local) without custom dialer via physical interface",
destination: "127.0.10.2:53",
expectedSourceIP: "127.0.0.1",
expectedDestPrefix: "127.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
},
}
func TestRouting(t *testing.T) {
log.SetLevel(log.DebugLevel)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
setupTestEnv(t)
@@ -129,7 +115,7 @@ func TestRouting(t *testing.T) {
require.NoError(t, err, "Failed to fetch interface IP")
output := testRoute(t, tc.destination, tc.dialer)
if tc.expectedInterface == expectedExtInt {
if tc.expectedInterface == expectedExternalInt {
verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias)
} else {
verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface)
@@ -242,19 +228,23 @@ func setupDummyInterfacesAndRoutes(t *testing.T) {
func addDummyRoute(t *testing.T, dstCIDR string) {
t.Helper()
script := fmt.Sprintf(`New-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -PolicyStore ActiveStore`, dstCIDR)
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
prefix, err := netip.ParsePrefix(dstCIDR)
if err != nil {
t.Logf("Failed to add dummy route: %v\nOutput: %s", err, output)
t.FailNow()
t.Fatalf("Failed to parse destination CIDR %s: %v", dstCIDR, err)
}
nexthop := Nexthop{
Intf: &net.Interface{Index: 1},
}
if err = addRoute(prefix, nexthop); err != nil {
t.Fatalf("Failed to add dummy route: %v", err)
}
t.Cleanup(func() {
script = fmt.Sprintf(`Remove-NetRoute -DestinationPrefix "%s" -InterfaceIndex 1 -Confirm:$false`, dstCIDR)
output, err := exec.Command("powershell", "-Command", script).CombinedOutput()
err := deleteRoute(prefix, nexthop)
if err != nil {
t.Logf("Failed to remove dummy route: %v\nOutput: %s", err, output)
t.Logf("Failed to remove dummy route: %v", err)
}
})
}