Release 0.28.0 (#2092)

* compile client under freebsd (#1620)

Compile netbird client under freebsd and now support netstack and userspace modes.
Refactoring linux specific code to share same code with FreeBSD, move to *_unix.go files.

Not implemented yet:

Kernel mode not supported
DNS probably does not work yet
Routing also probably does not work yet
SSH support did not tested yet
Lack of test environment for freebsd (dedicated VM for github runners under FreeBSD required)
Lack of tests for freebsd specific code
info reporting need to review and also implement, for example OS reported as GENERIC instead of FreeBSD (lack of FreeBSD icon in management interface)
Lack of proper client setup under FreeBSD
Lack of FreeBSD port/package

* Add DNS routes (#1943)

Given domains are resolved periodically and resolved IPs are replaced with the new ones. Unless the flag keep_route is set to true, then only new ones are added.
This option is helpful if there are long-running connections that might still point to old IP addresses from changed DNS records.

* Add process posture check (#1693)

Introduces a process posture check to validate the existence and active status of specific binaries on peer systems. The check ensures that files are present at specified paths, and that corresponding processes are running. This check supports Linux, Windows, and macOS systems.


Co-authored-by: Evgenii <mail@skillcoder.com>
Co-authored-by: Pascal Fischer <pascal@netbird.io>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
Maycon Santos
2024-06-13 13:24:24 +02:00
committed by GitHub
parent 95299be52d
commit 4fec709bb1
149 changed files with 6509 additions and 2710 deletions

View File

@@ -0,0 +1,18 @@
//go:build darwin || dragonfly || netbsd || openbsd
package systemops
import "syscall"
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
return true
}
return false
}

View File

@@ -0,0 +1,19 @@
//go:build: freebsd
package systemops
import "syscall"
// filterRoutesByFlags - return true if need to ignore such route message because it consists specific flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 {
return true
}
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 (https://www.freebsd.org/releases/8.0R/relnotes-detailed/)
// a concept of cloned route (a route generated by an entry with RTF_CLONING flag) is deprecated.
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
return true
}
return false
}

View File

@@ -0,0 +1,27 @@
package systemops
import (
"net"
"net/netip"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/iface"
)
type Nexthop struct {
IP netip.Addr
Intf *net.Interface
}
type ExclusionCounter = refcounter.Counter[any, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
wgInterface *iface.WGIface
}
func NewSysOps(wgInterface *iface.WGIface) *SysOps {
return &SysOps{
wgInterface: wgInterface,
}
}

View File

@@ -0,0 +1,160 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package systemops
import (
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"syscall"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
)
type Route struct {
Dst netip.Prefix
Gw netip.Addr
Interface *net.Interface
}
func getRoutesFromTable() ([]netip.Prefix, error) {
tab, err := retryFetchRIB()
if err != nil {
return nil, fmt.Errorf("fetch RIB: %v", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, tab)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
}
var prefixList []netip.Prefix
for _, msg := range msgs {
m := msg.(*route.RouteMessage)
if m.Version < 3 || m.Version > 5 {
return nil, fmt.Errorf("unexpected RIB message version: %d", m.Version)
}
if m.Type != syscall.RTM_GET {
return nil, fmt.Errorf("unexpected RIB message type: %d", m.Type)
}
if filterRoutesByFlags(m.Flags) {
continue
}
route, err := MsgToRoute(m)
if err != nil {
log.Warnf("Failed to parse route message: %v", err)
continue
}
if route.Dst.IsValid() {
prefixList = append(prefixList, route.Dst)
}
}
return prefixList, nil
}
func retryFetchRIB() ([]byte, error) {
var out []byte
operation := func() error {
var err error
out, err = route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0)
if errors.Is(err, syscall.ENOMEM) {
log.Debug("~etrying fetchRIB due to 'cannot allocate memory' error")
return err
} else if err != nil {
return backoff.Permanent(err)
}
return nil
}
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = 1 * time.Second
err := backoff.Retry(operation, expBackOff)
if err != nil {
return nil, fmt.Errorf("failed to fetch routing information: %w", err)
}
return out, nil
}
func toNetIP(a route.Addr) netip.Addr {
switch t := a.(type) {
case *route.Inet4Addr:
return netip.AddrFrom4(t.IP)
case *route.Inet6Addr:
ip := netip.AddrFrom16(t.IP)
if t.ZoneID != 0 {
ip.WithZone(strconv.Itoa(t.ZoneID))
}
return ip
default:
return netip.Addr{}
}
}
// ones returns the number of leading ones in the mask.
func ones(a route.Addr) (int, error) {
switch t := a.(type) {
case *route.Inet4Addr:
mask, _ := net.IPMask(t.IP[:]).Size()
return mask, nil
case *route.Inet6Addr:
mask, _ := net.IPMask(t.IP[:]).Size()
return mask, nil
default:
return 0, fmt.Errorf("unexpected address type: %T", a)
}
}
// MsgToRoute converts a route message to a Route.
func MsgToRoute(msg *route.RouteMessage) (*Route, error) {
dstIP, nexthop, dstMask := msg.Addrs[0], msg.Addrs[1], msg.Addrs[2]
addr := toNetIP(dstIP)
var nexthopAddr netip.Addr
var nexthopIntf *net.Interface
switch t := nexthop.(type) {
case *route.Inet4Addr, *route.Inet6Addr:
nexthopAddr = toNetIP(t)
case *route.LinkAddr:
nexthopIntf = &net.Interface{
Index: t.Index,
Name: t.Name,
}
default:
return nil, fmt.Errorf("unexpected next hop type: %T", t)
}
var prefix netip.Prefix
if dstMask == nil {
if addr.Is4() {
prefix = netip.PrefixFrom(addr, 32)
} else {
prefix = netip.PrefixFrom(addr, 128)
}
} else {
bits, err := ones(dstMask)
if err != nil {
return nil, fmt.Errorf("failed to parse mask: %v", dstMask)
}
prefix = netip.PrefixFrom(addr, bits)
}
return &Route{
Dst: prefix,
Gw: nexthopAddr,
Interface: nexthopIntf,
}, nil
}

View File

@@ -0,0 +1,57 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package systemops
import (
"testing"
"github.com/stretchr/testify/assert"
"golang.org/x/net/route"
)
func TestBits(t *testing.T) {
tests := []struct {
name string
addr route.Addr
want int
wantErr bool
}{
{
name: "IPv4 all ones",
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 255}},
want: 32,
},
{
name: "IPv4 normal mask",
addr: &route.Inet4Addr{IP: [4]byte{255, 255, 255, 0}},
want: 24,
},
{
name: "IPv6 all ones",
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}},
want: 128,
},
{
name: "IPv6 normal mask",
addr: &route.Inet6Addr{IP: [16]byte{255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0, 0, 0}},
want: 64,
},
{
name: "Unsupported type",
addr: &route.LinkAddr{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ones(tt.addr)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}

View File

@@ -0,0 +1,140 @@
//go:build !ios
package systemops
import (
"fmt"
"net"
"net/netip"
"os/exec"
"regexp"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var expectedVPNint = "utun100"
var expectedExternalInt = "lo0"
var expectedInternalInt = "lo0"
func init() {
testCases = append(testCases, []testCase{
{
name: "To more specific route without custom dialer via vpn",
destination: "10.10.0.2:53",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
},
}...)
}
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
intf := &net.Interface{Name: "lo0"}
r := NewSysOps(nil)
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
baseIP = netip.MustParseAddr("192.0.2.0")
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.removeFromRouteTable(prefix, Nexthop{netip.Addr{}, intf}); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
}
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return "lo0"
}
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) {
t.Helper()
var originalNexthop net.IP
if dstCIDR == "0.0.0.0/0" {
var err error
originalNexthop, err = fetchOriginalGateway()
if err != nil {
t.Logf("Failed to fetch original gateway: %v", err)
}
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
t.Logf("Failed to delete route: %v, output: %s", err, output)
}
}
t.Cleanup(func() {
if originalNexthop != nil {
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
assert.NoError(t, err, "Failed to restore original route")
}
})
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
require.NoError(t, err, "Failed to add route")
t.Cleanup(func() {
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
assert.NoError(t, err, "Failed to remove route")
})
}
func fetchOriginalGateway() (net.IP, error) {
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
if err != nil {
return nil, err
}
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
if len(matches) == 0 {
return nil, fmt.Errorf("gateway not found")
}
return net.ParseIP(matches[1]), nil
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
}

View File

@@ -0,0 +1,416 @@
//go:build !android && !ios
package systemops
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
func (r *SysOps) setupRefCounter(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
log.Errorf("Unable to get initial v4 default next hop: %v", err)
}
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
log.Errorf("Unable to get initial v6 default next hop: %v", err)
}
refCounter := refcounter.New(
func(prefix netip.Prefix, _ any) (Nexthop, error) {
initialNexthop := initialNextHopV4
if prefix.Addr().Is6() {
initialNexthop = initialNextHopV6
}
nexthop, err := r.addRouteToNonVPNIntf(prefix, r.wgInterface, initialNexthop)
if errors.Is(err, vars.ErrRouteNotAllowed) || errors.Is(err, vars.ErrRouteNotFound) {
log.Tracef("Adding for prefix %s: %v", prefix, err)
// These errors are not critical but also we should not track and try to remove the routes either.
return nexthop, refcounter.ErrIgnore
}
return nexthop, err
},
r.removeFromRouteTable,
)
r.refCounter = refCounter
return r.setupHooks(initAddresses)
}
func (r *SysOps) cleanupRefCounter() error {
if r.refCounter == nil {
return nil
}
// TODO: Remove hooks selectively
nbnet.RemoveDialerHooks()
nbnet.RemoveListenerHooks()
if err := r.refCounter.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
}
return nil
}
// TODO: fix: for default our wg address now appears as the default gw
func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
addr := netip.IPv4Unspecified()
if prefix.Addr().Is6() {
addr = netip.IPv6Unspecified()
}
nexthop, err := GetNextHop(addr)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("get existing route gateway: %s", err)
}
if !prefix.Contains(nexthop.IP) {
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", nexthop.IP, prefix)
return nil
}
gatewayPrefix := netip.PrefixFrom(nexthop.IP, 32)
if nexthop.IP.Is6() {
gatewayPrefix = netip.PrefixFrom(nexthop.IP, 128)
}
ok, err := existsInRouteTable(gatewayPrefix)
if err != nil {
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
}
if ok {
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
return nil
}
nexthop, err = GetNextHop(nexthop.IP)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
}
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, nexthop.IP)
return r.addToRouteTable(gatewayPrefix, nexthop)
}
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
addr.IsLinkLocalUnicast(),
addr.IsLinkLocalMulticast(),
addr.IsInterfaceLocalMulticast(),
addr.IsUnspecified(),
addr.IsMulticast():
return Nexthop{}, vars.ErrRouteNotAllowed
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route
nexthop, err := GetNextHop(addr)
if err != nil {
return Nexthop{}, fmt.Errorf("get next hop: %w", err)
}
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP)
exitNextHop := Nexthop{
IP: nexthop.IP,
Intf: nexthop.Intf,
}
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
if !ok {
return Nexthop{}, fmt.Errorf("failed to convert vpn address to netip.Addr")
}
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
exitNextHop = initialNextHop
}
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop.IP)
if err := r.addToRouteTable(prefix, exitNextHop); err != nil {
return Nexthop{}, fmt.Errorf("add route to table: %w", err)
}
return exitNextHop, nil
}
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
// in two /1 prefixes to avoid replacing the existing default route
func (r *SysOps) genericAddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
nextHop := Nexthop{netip.Addr{}, intf}
if prefix == vars.Defaultv4 {
if err := r.addToRouteTable(splitDefaultv4_1, nextHop); err != nil {
return err
}
if err := r.addToRouteTable(splitDefaultv4_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return err
}
// TODO: remove once IPv6 is supported on the interface
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
} else if prefix == vars.Defaultv6 {
if err := r.addToRouteTable(splitDefaultv6_1, nextHop); err != nil {
return fmt.Errorf("add unreachable route split 1: %w", err)
}
if err := r.addToRouteTable(splitDefaultv6_2, nextHop); err != nil {
if err2 := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err2 != nil {
log.Warnf("Failed to rollback route addition: %s", err2)
}
return fmt.Errorf("add unreachable route split 2: %w", err)
}
return nil
}
return r.addNonExistingRoute(prefix, intf)
}
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
func (r *SysOps) addNonExistingRoute(prefix netip.Prefix, intf *net.Interface) error {
ok, err := existsInRouteTable(prefix)
if err != nil {
return fmt.Errorf("exists in route table: %w", err)
}
if ok {
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
return nil
}
ok, err = isSubRange(prefix)
if err != nil {
return fmt.Errorf("sub range: %w", err)
}
if ok {
if err := r.addRouteForCurrentDefaultGateway(prefix); err != nil {
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
}
}
return r.addToRouteTable(prefix, Nexthop{netip.Addr{}, intf})
}
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
// it will remove the split /1 prefixes
func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
nextHop := Nexthop{netip.Addr{}, intf}
if prefix == vars.Defaultv4 {
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv4_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv4_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
// TODO: remove once IPv6 is supported on the interface
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
return nberrors.FormatErrorOrNil(result)
} else if prefix == vars.Defaultv6 {
var result *multierror.Error
if err := r.removeFromRouteTable(splitDefaultv6_1, nextHop); err != nil {
result = multierror.Append(result, err)
}
if err := r.removeFromRouteTable(splitDefaultv6_2, nextHop); err != nil {
result = multierror.Append(result, err)
}
return nberrors.FormatErrorOrNil(result)
}
return r.removeFromRouteTable(prefix, nextHop)
}
func (r *SysOps) setupHooks(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
return nil
}
for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err)
}
}
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
if ctx.Err() != nil {
return ctx.Err()
}
var result *multierror.Error
for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP))
}
return nberrors.FormatErrorOrNil(result)
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
return beforeHook, afterHook, nil
}
func GetNextHop(ip netip.Addr) (Nexthop, error) {
r, err := netroute.New()
if err != nil {
return Nexthop{}, fmt.Errorf("new netroute: %w", err)
}
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
if err != nil {
log.Debugf("Failed to get route for %s: %v", ip, err)
return Nexthop{}, vars.ErrRouteNotFound
}
log.Debugf("Route for %s: interface %v nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
if gateway == nil {
if runtime.GOOS == "freebsd" {
return Nexthop{Intf: intf}, nil
}
if preferredSrc == nil {
return Nexthop{}, vars.ErrRouteNotFound
}
log.Debugf("No next hop found for IP %s, using preferred source %s", ip, preferredSrc)
addr, err := ipToAddr(preferredSrc, intf)
if err != nil {
return Nexthop{}, fmt.Errorf("convert preferred source to address: %w", err)
}
return Nexthop{
IP: addr.Unmap(),
Intf: intf,
}, nil
}
addr, err := ipToAddr(gateway, intf)
if err != nil {
return Nexthop{}, fmt.Errorf("convert gateway to address: %w", err)
}
return Nexthop{
IP: addr,
Intf: intf,
}, nil
}
// converts a net.IP to a netip.Addr including the zone based on the passed interface
func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip)
}
if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) {
log.Tracef("Adding zone %s to address %s", intf.Name, addr)
if runtime.GOOS == "windows" {
addr = addr.WithZone(strconv.Itoa(intf.Index))
} else {
addr = addr.WithZone(intf.Name)
}
}
return addr.Unmap(), nil
}
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute == prefix {
return true, nil
}
}
return false, nil
}
func isSubRange(prefix netip.Prefix) (bool, error) {
routes, err := getRoutesFromTable()
if err != nil {
return false, fmt.Errorf("get routes from table: %w", err)
}
for _, tableRoute := range routes {
if tableRoute.Bits() > vars.MinRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
return true, nil
}
}
return false, nil
}

View File

@@ -0,0 +1,412 @@
//go:build !android && !ios
package systemops
import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
"os"
"runtime"
"strings"
"testing"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/iface"
)
type dialer interface {
Dial(network, address string) (net.Conn, error)
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
func TestAddRemoveRoutes(t *testing.T) {
testCases := []struct {
name string
prefix netip.Prefix
shouldRouteToWireguard bool
shouldBeRemoved bool
}{
{
name: "Should Add And Remove Route 100.66.120.0/24",
prefix: netip.MustParsePrefix("100.66.120.0/24"),
shouldRouteToWireguard: true,
shouldBeRemoved: true,
},
{
name: "Should Not Add Or Remove Route 127.0.0.1/32",
prefix: netip.MustParsePrefix("127.0.0.1/32"),
shouldRouteToWireguard: false,
shouldBeRemoved: false,
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
r := NewSysOps(wgInterface)
_, _, err = r.SetupRouting(nil)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting())
})
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
err = r.AddVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericAddVPNRoute should not return err")
if testCase.shouldRouteToWireguard {
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
} else {
assertWGOutInterface(t, testCase.prefix, wgInterface, true)
}
exists, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "existsInRouteTable should not return err")
if exists && testCase.shouldRouteToWireguard {
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
prefixNexthop, err := GetNextHop(testCase.prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err")
internetNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
require.NoError(t, err)
if testCase.shouldBeRemoved {
require.Equal(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to default internet gateway")
} else {
require.NotEqual(t, internetNexthop.IP, prefixNexthop.IP, "route should be pointing to a different gateway than the internet gateway")
}
}
})
}
}
func TestGetNextHop(t *testing.T) {
nexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
if !nexthop.IP.IsValid() {
t.Fatal("should return a gateway")
}
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var testingIP string
var testingPrefix netip.Prefix
for _, address := range addresses {
if address.Network() != "ip+net" {
continue
}
prefix := netip.MustParsePrefix(address.String())
if !prefix.Addr().IsLoopback() && prefix.Addr().Is4() {
testingIP = prefix.Addr().String()
testingPrefix = prefix.Masked()
break
}
}
localIP, err := GetNextHop(testingPrefix.Addr())
if err != nil {
t.Fatal("shouldn't return error: ", err)
}
if !localIP.IP.IsValid() {
t.Fatal("should return a gateway for local network")
}
if localIP.IP.String() == nexthop.IP.String() {
t.Fatal("local IP should not match with gateway IP")
}
if localIP.IP.String() != testingIP {
t.Fatalf("local IP should match with testing IP: want %s got %s", testingIP, localIP.IP.String())
}
}
func TestAddExistAndRemoveRoute(t *testing.T) {
defaultNexthop, err := GetNextHop(netip.MustParseAddr("0.0.0.0"))
t.Log("defaultNexthop: ", defaultNexthop)
if err != nil {
t.Fatal("shouldn't return error when fetching the gateway: ", err)
}
testCases := []struct {
name string
prefix netip.Prefix
preExistingPrefix netip.Prefix
shouldAddRoute bool
}{
{
name: "Should Add And Remove random Route",
prefix: netip.MustParsePrefix("99.99.99.99/32"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if overlaps with default gateway",
prefix: netip.MustParsePrefix(defaultNexthop.IP.String() + "/31"),
shouldAddRoute: false,
},
{
name: "Should Add Route if bigger network exists",
prefix: netip.MustParsePrefix("100.100.100.0/24"),
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: true,
},
{
name: "Should Add Route if smaller network exists",
prefix: netip.MustParsePrefix("100.100.0.0/16"),
preExistingPrefix: netip.MustParsePrefix("100.100.100.0/24"),
shouldAddRoute: true,
},
{
name: "Should Not Add Route if same network exists",
prefix: netip.MustParsePrefix("100.100.0.0/16"),
preExistingPrefix: netip.MustParsePrefix("100.100.0.0/16"),
shouldAddRoute: false,
},
}
for n, testCase := range testCases {
var buf bytes.Buffer
log.SetOutput(&buf)
defer func() {
log.SetOutput(os.Stderr)
}()
t.Run(testCase.name, func(t *testing.T) {
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
t.Setenv("NB_DISABLE_ROUTE_CACHE", "true")
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
if err != nil {
t.Fatal(err)
}
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close()
err = wgInterface.Create()
require.NoError(t, err, "should create testing wireguard interface")
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
r := NewSysOps(wgInterface)
// Prepare the environment
if testCase.preExistingPrefix.IsValid() {
err := r.AddVPNRoute(testCase.preExistingPrefix, intf)
require.NoError(t, err, "should not return err when adding pre-existing route")
}
// Add the route
err = r.AddVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err when adding route")
if testCase.shouldAddRoute {
// test if route exists after adding
ok, err := existsInRouteTable(testCase.prefix)
require.NoError(t, err, "should not return err")
require.True(t, ok, "route should exist")
// remove route again if added
err = r.RemoveVPNRoute(testCase.prefix, intf)
require.NoError(t, err, "should not return err")
}
// route should either not have been added or should have been removed
// In case of already existing route, it should not have been added (but still exist)
ok, err := existsInRouteTable(testCase.prefix)
t.Log("Buffer string: ", buf.String())
require.NoError(t, err, "should not return err")
if !strings.Contains(buf.String(), "because it already exists") {
require.False(t, ok, "route should not exist")
}
})
}
}
func TestIsSubRange(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var subRangeAddressPrefixes []netip.Prefix
var nonSubRangeAddressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 {
p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1)
subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2)
nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked())
}
}
for _, prefix := range subRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if !isSubRangePrefix {
t.Fatalf("address %s should be sub-range of an existing route in the table", prefix)
}
}
for _, prefix := range nonSubRangeAddressPrefixes {
isSubRangePrefix, err := isSubRange(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address is sub-range: ", err)
}
if isSubRangePrefix {
t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix)
}
}
}
func TestExistsInRouteTable(t *testing.T) {
addresses, err := net.InterfaceAddrs()
if err != nil {
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
}
var addressPrefixes []netip.Prefix
for _, address := range addresses {
p := netip.MustParsePrefix(address.String())
if p.Addr().Is6() {
continue
}
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() {
continue
}
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
if runtime.GOOS == "linux" && p.Addr().IsLoopback() {
continue
}
addressPrefixes = append(addressPrefixes, p.Masked())
}
for _, prefix := range addressPrefixes {
exists, err := existsInRouteTable(prefix)
if err != nil {
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
}
if !exists {
t.Fatalf("address %s should exist in route table", prefix)
}
}
}
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
t.Helper()
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
newNet, err := stdnet.NewNet()
require.NoError(t, err)
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
require.NoError(t, err, "should create testing WireGuard interface")
err = wgInterface.Create()
require.NoError(t, err, "should create testing WireGuard interface")
t.Cleanup(func() {
wgInterface.Close()
})
return wgInterface
}
func setupRouteAndCleanup(t *testing.T, r *SysOps, prefix netip.Prefix, intf *net.Interface) {
t.Helper()
err := r.AddVPNRoute(prefix, intf)
require.NoError(t, err, "addVPNRoute should not return err")
t.Cleanup(func() {
err = r.RemoveVPNRoute(prefix, intf)
assert.NoError(t, err, "removeVPNRoute should not return err")
})
}
func setupTestEnv(t *testing.T) {
t.Helper()
setupDummyInterfacesAndRoutes(t)
wgInterface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
t.Cleanup(func() {
assert.NoError(t, wgInterface.Close())
})
r := NewSysOps(wgInterface)
_, _, err := r.SetupRouting(nil)
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting())
})
index, err := net.InterfaceByName(wgInterface.Name())
require.NoError(t, err, "InterfaceByName should not return err")
intf := &net.Interface{Index: index.Index, Name: wgInterface.Name()}
// default route exists in main table and vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("0.0.0.0/0"), intf)
// 10.0.0.0/8 route exists in main table and vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.0.0.0/8"), intf)
// 10.10.0.0/24 more specific route exists in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("10.10.0.0/24"), intf)
// 127.0.10.0/24 more specific route exists in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("127.0.10.0/24"), intf)
// unique route in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
}
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
t.Helper()
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
return
}
prefixNexthop, err := GetNextHop(prefix.Addr())
require.NoError(t, err, "GetNextHop should not return err")
if invert {
assert.NotEqual(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should not point to wireguard interface IP")
} else {
assert.Equal(t, wgIface.Address().IP.String(), prefixNexthop.IP.String(), "route should point to wireguard interface IP")
}
}

View File

@@ -0,0 +1,504 @@
//go:build !android
package systemops
import (
"bufio"
"errors"
"fmt"
"net"
"net/netip"
"os"
"syscall"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
// NetbirdVPNTableID is the ID of the custom routing table used by Netbird.
NetbirdVPNTableID = 0x1BD0
// NetbirdVPNTableName is the name of the custom routing table used by Netbird.
NetbirdVPNTableName = "netbird"
// rtTablesPath is the path to the file containing the routing table names.
rtTablesPath = "/etc/iproute2/rt_tables"
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
ipv4ForwardingPath = "net.ipv4.ip_forward"
)
var ErrTableIDExists = errors.New("ID exists with different name")
// originalSysctl stores the original sysctl values before they are modified
var originalSysctl map[string]int
// sysctlFailed is used as an indicator to emit a warning when default routes are configured
var sysctlFailed bool
type ruleParams struct {
priority int
fwmark int
tableID int
family int
invert bool
suppressPrefix int
description string
}
// isLegacy determines whether to use the legacy routing setup
func isLegacy() bool {
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
}
// setIsLegacy sets the legacy routing setup
func setIsLegacy(b bool) {
if b {
os.Setenv("NB_USE_LEGACY_ROUTING", "true")
} else {
os.Unsetenv("NB_USE_LEGACY_ROUTING")
}
}
func getSetupRules() []ruleParams {
return []ruleParams{
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, false, 0, "rule with suppress prefixlen v4"},
{100, -1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, false, 0, "rule with suppress prefixlen v6"},
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, true, -1, "rule v4 netbird"},
{110, nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, true, -1, "rule v6 netbird"},
}
}
// SetupRouting establishes the routing configuration for the VPN, including essential rules
// to ensure proper traffic flow for management, locally configured routes, and VPN traffic.
//
// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over
// potential routes received and configured for the VPN. This rule is skipped for the default route and routes
// that are not in the main table.
//
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
if isLegacy() {
log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses)
}
if err = addRoutingTableName(); err != nil {
log.Errorf("Error adding routing table name: %v", err)
}
originalValues, err := sysctl.Setup(r.wgInterface)
if err != nil {
log.Errorf("Error setting up sysctl: %v", err)
sysctlFailed = true
}
originalSysctl = originalValues
defer func() {
if err != nil {
if cleanErr := r.CleanupRouting(); cleanErr != nil {
log.Errorf("Error cleaning up routing: %v", cleanErr)
}
}
}()
rules := getSetupRules()
for _, rule := range rules {
if err := addRule(rule); err != nil {
if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
setIsLegacy(true)
return r.setupRefCounter(initAddresses)
}
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
}
}
return nil, nil, nil
}
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process.
func (r *SysOps) CleanupRouting() error {
if isLegacy() {
return r.cleanupRefCounter()
}
var result *multierror.Error
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
result = multierror.Append(result, fmt.Errorf("flush routes v4: %w", err))
}
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
result = multierror.Append(result, fmt.Errorf("flush routes v6: %w", err))
}
rules := getSetupRules()
for _, rule := range rules {
if err := removeRule(rule); err != nil {
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
}
}
if err := sysctl.Cleanup(originalSysctl); err != nil {
result = multierror.Append(result, fmt.Errorf("cleanup sysctl: %w", err))
}
originalSysctl = nil
sysctlFailed = false
return nberrors.FormatErrorOrNil(result)
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return addRoute(prefix, nexthop, syscall.RT_TABLE_MAIN)
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return removeRoute(prefix, nexthop, syscall.RT_TABLE_MAIN)
}
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() {
return r.genericAddVPNRoute(prefix, intf)
}
if sysctlFailed && (prefix == vars.Defaultv4 || prefix == vars.Defaultv6) {
log.Warnf("Default route is configured but sysctl operations failed, VPN traffic may not be routed correctly, consider using NB_USE_LEGACY_ROUTING=true or setting net.ipv4.conf.*.rp_filter to 2 (loose) or 0 (off)")
}
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
// TODO remove this once we have ipv6 support
if prefix == vars.Defaultv4 {
if err := addUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add blackhole: %w", err)
}
}
if err := addRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
return fmt.Errorf("add route: %w", err)
}
return nil
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
if isLegacy() {
return r.genericRemoveVPNRoute(prefix, intf)
}
// TODO remove this once we have ipv6 support
if prefix == vars.Defaultv4 {
if err := removeUnreachableRoute(vars.Defaultv6, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove unreachable route: %w", err)
}
}
if err := removeRoute(prefix, Nexthop{netip.Addr{}, intf}, NetbirdVPNTableID); err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil
}
func getRoutesFromTable() ([]netip.Prefix, error) {
v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4)
if err != nil {
return nil, fmt.Errorf("get v4 routes: %w", err)
}
v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6)
if err != nil {
return nil, fmt.Errorf("get v6 routes: %w", err)
}
return append(v4Routes, v6Routes...), nil
}
// getRoutes fetches routes from a specific routing table identified by tableID.
func getRoutes(tableID, family int) ([]netip.Prefix, error) {
var prefixList []netip.Prefix
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
if err != nil {
return nil, fmt.Errorf("list routes from table %d: %v", tableID, err)
}
for _, route := range routes {
if route.Dst != nil {
addr, ok := netip.AddrFromSlice(route.Dst.IP)
if !ok {
return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP)
}
ones, _ := route.Dst.Mask.Size()
prefix := netip.PrefixFrom(addr, ones)
if prefix.IsValid() {
prefixList = append(prefixList, prefix)
}
}
}
return prefixList, nil
}
// addRoute adds a route to a specific routing table identified by tableID.
func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Table: tableID,
Family: getAddressFamily(prefix),
}
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
}
route.Dst = ipNet
if err := addNextHop(nexthop, route); err != nil {
return fmt.Errorf("add gateway and device: %w", err)
}
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("netlink add route: %w", err)
}
return nil
}
// addUnreachableRoute adds an unreachable route for the specified IP family and routing table.
// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6.
// tableID specifies the routing table to which the unreachable route will be added.
func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
}
route := &netlink.Route{
Type: syscall.RTN_UNREACHABLE,
Table: tableID,
Family: getAddressFamily(prefix),
Dst: ipNet,
}
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("netlink add unreachable route: %w", err)
}
return nil
}
func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
}
route := &netlink.Route{
Type: syscall.RTN_UNREACHABLE,
Table: tableID,
Family: getAddressFamily(prefix),
Dst: ipNet,
}
if err := netlink.RouteDel(route); err != nil &&
!errors.Is(err, syscall.ESRCH) &&
!errors.Is(err, syscall.ENOENT) &&
!errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("netlink remove unreachable route: %w", err)
}
return nil
}
// removeRoute removes a route from a specific routing table identified by tableID.
func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
_, ipNet, err := net.ParseCIDR(prefix.String())
if err != nil {
return fmt.Errorf("parse prefix %s: %w", prefix, err)
}
route := &netlink.Route{
Scope: netlink.SCOPE_UNIVERSE,
Table: tableID,
Family: getAddressFamily(prefix),
Dst: ipNet,
}
if err := addNextHop(nexthop, route); err != nil {
return fmt.Errorf("add gateway and device: %w", err)
}
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("netlink remove route: %w", err)
}
return nil
}
func flushRoutes(tableID, family int) error {
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
if err != nil {
return fmt.Errorf("list routes from table %d: %w", tableID, err)
}
var result *multierror.Error
for i := range routes {
route := routes[i]
// unreachable default routes don't come back with Dst set
if route.Gw == nil && route.Src == nil && route.Dst == nil {
if family == netlink.FAMILY_V4 {
routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)}
} else {
routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}
}
}
if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err))
}
}
return nberrors.FormatErrorOrNil(result)
}
func EnableIPForwarding() error {
_, err := sysctl.Set(ipv4ForwardingPath, 1, false)
return err
}
// entryExists checks if the specified ID or name already exists in the rt_tables file
// and verifies if existing names start with "netbird_".
func entryExists(file *os.File, id int) (bool, error) {
if _, err := file.Seek(0, 0); err != nil {
return false, fmt.Errorf("seek rt_tables: %w", err)
}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
var existingID int
var existingName string
if _, err := fmt.Sscanf(line, "%d %s\n", &existingID, &existingName); err == nil {
if existingID == id {
if existingName != NetbirdVPNTableName {
return true, ErrTableIDExists
}
return true, nil
}
}
}
if err := scanner.Err(); err != nil {
return false, fmt.Errorf("scan rt_tables: %w", err)
}
return false, nil
}
// addRoutingTableName adds human-readable names for custom routing tables.
func addRoutingTableName() error {
file, err := os.Open(rtTablesPath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
}
return fmt.Errorf("open rt_tables: %w", err)
}
defer func() {
if err := file.Close(); err != nil {
log.Errorf("Error closing rt_tables: %v", err)
}
}()
exists, err := entryExists(file, NetbirdVPNTableID)
if err != nil {
return fmt.Errorf("verify entry %d, %s: %w", NetbirdVPNTableID, NetbirdVPNTableName, err)
}
if exists {
return nil
}
// Reopen the file in append mode to add new entries
if err := file.Close(); err != nil {
log.Errorf("Error closing rt_tables before appending: %v", err)
}
file, err = os.OpenFile(rtTablesPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644)
if err != nil {
return fmt.Errorf("open rt_tables for appending: %w", err)
}
if _, err := file.WriteString(fmt.Sprintf("\n%d\t%s\n", NetbirdVPNTableID, NetbirdVPNTableName)); err != nil {
return fmt.Errorf("append entry to rt_tables: %w", err)
}
return nil
}
// addRule adds a routing rule to a specific routing table identified by tableID.
func addRule(params ruleParams) error {
rule := netlink.NewRule()
rule.Table = params.tableID
rule.Mark = params.fwmark
rule.Family = params.family
rule.Priority = params.priority
rule.Invert = params.invert
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("add routing rule: %w", err)
}
return nil
}
// removeRule removes a routing rule from a specific routing table identified by tableID.
func removeRule(params ruleParams) error {
rule := netlink.NewRule()
rule.Table = params.tableID
rule.Mark = params.fwmark
rule.Family = params.family
rule.Invert = params.invert
rule.Priority = params.priority
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("remove routing rule: %w", err)
}
return nil
}
// addNextHop adds the gateway and device to the route.
func addNextHop(nexthop Nexthop, route *netlink.Route) error {
if nexthop.Intf != nil {
route.LinkIndex = nexthop.Intf.Index
}
if nexthop.IP.IsValid() {
route.Gw = nexthop.IP.AsSlice()
// if zone is set, it means the gateway is a link-local address, so we set the link index
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
link, err := netlink.LinkByName(nexthop.IP.Zone())
if err != nil {
return fmt.Errorf("get link by name for zone %s: %w", nexthop.IP.Zone(), err)
}
route.LinkIndex = link.Attrs().Index
}
}
return nil
}
func getAddressFamily(prefix netip.Prefix) int {
if prefix.Addr().Is4() {
return netlink.FAMILY_V4
}
return netlink.FAMILY_V6
}

View File

@@ -0,0 +1,209 @@
//go:build !android
package systemops
import (
"errors"
"fmt"
"net"
"os"
"strings"
"syscall"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
var expectedVPNint = "wgtest0"
var expectedLoopbackInt = "lo"
var expectedExternalInt = "dummyext0"
var expectedInternalInt = "dummyint0"
func init() {
testCases = append(testCases, []testCase{
{
name: "To more specific route without custom dialer via physical interface",
destination: "10.10.0.2:53",
expectedInterface: expectedInternalInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53),
},
{
name: "To more specific route (local) without custom dialer via physical interface",
destination: "127.0.10.1:53",
expectedInterface: expectedLoopbackInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
},
}...)
}
func TestEntryExists(t *testing.T) {
tempDir := t.TempDir()
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
content := []string{
"1000 reserved",
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
"9999 other_table",
}
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
file, err := os.Open(tempFilePath)
require.NoError(t, err)
defer func() {
assert.NoError(t, file.Close())
}()
tests := []struct {
name string
id int
shouldExist bool
err error
}{
{
name: "ExistsWithNetbirdPrefix",
id: 7120,
shouldExist: true,
err: nil,
},
{
name: "ExistsWithDifferentName",
id: 1000,
shouldExist: true,
err: ErrTableIDExists,
},
{
name: "DoesNotExist",
id: 1234,
shouldExist: false,
err: nil,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
exists, err := entryExists(file, tc.id)
if tc.err != nil {
assert.ErrorIs(t, err, tc.err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.shouldExist, exists)
})
}
}
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
t.Helper()
dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}}
err := netlink.LinkDel(dummy)
if err != nil && !errors.Is(err, syscall.EINVAL) {
t.Logf("Failed to delete dummy interface: %v", err)
}
err = netlink.LinkAdd(dummy)
require.NoError(t, err)
err = netlink.LinkSetUp(dummy)
require.NoError(t, err)
if ipAddressCIDR != "" {
addr, err := netlink.ParseAddr(ipAddressCIDR)
require.NoError(t, err)
err = netlink.AddrAdd(dummy, addr)
require.NoError(t, err)
}
t.Cleanup(func() {
err := netlink.LinkDel(dummy)
assert.NoError(t, err)
})
return dummy.Name
}
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) {
t.Helper()
_, dstIPNet, err := net.ParseCIDR(dstCIDR)
require.NoError(t, err)
// Handle existing routes with metric 0
var originalNexthop net.IP
var originalLinkIndex int
if dstIPNet.String() == "0.0.0.0/0" {
var err error
originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4)
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {
t.Logf("Failed to fetch original gateway: %v", err)
}
if originalNexthop != nil {
err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0})
switch {
case err != nil && !errors.Is(err, syscall.ESRCH):
t.Logf("Failed to delete route: %v", err)
case err == nil:
t.Cleanup(func() {
err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0})
if err != nil && !errors.Is(err, syscall.EEXIST) {
t.Fatalf("Failed to add route: %v", err)
}
})
default:
t.Logf("Failed to delete route: %v", err)
}
}
}
link, err := netlink.LinkByName(intf)
require.NoError(t, err)
linkIndex := link.Attrs().Index
route := &netlink.Route{
Dst: dstIPNet,
Gw: gw,
LinkIndex: linkIndex,
}
err = netlink.RouteDel(route)
if err != nil && !errors.Is(err, syscall.ESRCH) {
t.Logf("Failed to delete route: %v", err)
}
err = netlink.RouteAdd(route)
if err != nil && !errors.Is(err, syscall.EEXIST) {
t.Fatalf("Failed to add route: %v", err)
}
require.NoError(t, err)
}
func fetchOriginalGateway(family int) (net.IP, int, error) {
routes, err := netlink.RouteList(nil, family)
if err != nil {
return nil, 0, err
}
for _, route := range routes {
if route.Dst == nil && route.Priority == 0 {
return route.Gw, route.LinkIndex, nil
}
}
return nil, 0, vars.ErrRouteNotFound
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
}

View File

@@ -0,0 +1,34 @@
//go:build ios || android
package systemops
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
)
func (r *SysOps) SetupRouting([]net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return nil, nil, nil
}
func (r *SysOps) CleanupRouting() error {
return nil
}
func (r *SysOps) AddVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func (r *SysOps) RemoveVPNRoute(netip.Prefix, *net.Interface) error {
return nil
}
func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}

View File

@@ -0,0 +1,24 @@
//go:build !linux && !ios
package systemops
import (
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
)
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return r.genericAddVPNRoute(prefix, intf)
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return r.genericRemoveVPNRoute(prefix, intf)
}
func EnableIPForwarding() error {
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
return nil
}

View File

@@ -0,0 +1,82 @@
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
package systemops
import (
"fmt"
"net"
"net/netip"
"os/exec"
"strings"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
)
func (r *SysOps) SetupRouting(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return r.setupRefCounter(initAddresses)
}
func (r *SysOps) CleanupRouting() error {
return r.cleanupRefCounter()
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeCmd("add", prefix, nexthop)
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeCmd("delete", prefix, nexthop)
}
func (r *SysOps) routeCmd(action string, prefix netip.Prefix, nexthop Nexthop) error {
inet := "-inet"
if prefix.Addr().Is6() {
inet = "-inet6"
}
network := prefix.String()
if prefix.IsSingleIP() {
network = prefix.Addr().String()
}
args := []string{"-n", action, inet, network}
if nexthop.IP.IsValid() {
args = append(args, nexthop.IP.Unmap().String())
} else if nexthop.Intf != nil {
args = append(args, "-interface", nexthop.Intf.Name)
}
if err := retryRouteCmd(args); err != nil {
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
}
return nil
}
func retryRouteCmd(args []string) error {
operation := func() error {
out, err := exec.Command("route", args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
// https://github.com/golang/go/issues/45736
if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") {
return err
} else if err != nil {
return backoff.Permanent(err)
}
return nil
}
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = 1 * time.Second
err := backoff.Retry(operation, expBackOff)
if err != nil {
return fmt.Errorf("route cmd retry failed: %w", err)
}
return nil
}

View File

@@ -0,0 +1,234 @@
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
package systemops
import (
"fmt"
"net"
"strings"
"testing"
"time"
"github.com/gopacket/gopacket"
"github.com/gopacket/gopacket/layers"
"github.com/gopacket/gopacket/pcap"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/util/net"
)
type PacketExpectation struct {
SrcIP net.IP
DstIP net.IP
SrcPort int
DstPort int
UDP bool
TCP bool
}
type testCase struct {
name string
destination string
expectedInterface string
dialer dialer
expectedPacket PacketExpectation
}
var testCases = []testCase{
{
name: "To external host without custom dialer via vpn",
destination: "192.0.2.1:53",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To external host with custom dialer via physical interface",
destination: "192.0.2.1:53",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To duplicate internal route with custom dialer via physical interface",
destination: "10.0.0.2:53",
expectedInterface: expectedInternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
destination: "10.0.0.2:53",
expectedInterface: expectedInternalInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To unique vpn route with custom dialer via physical interface",
destination: "172.16.0.2:53",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
},
{
name: "To unique vpn route without custom dialer via vpn",
destination: "172.16.0.2:53",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
},
}
func TestRouting(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
setupTestEnv(t)
filter := createBPFFilter(tc.destination)
handle := startPacketCapture(t, tc.expectedInterface, filter)
sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer)
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
packet, err := packetSource.NextPacket()
require.NoError(t, err)
verifyPacket(t, packet, tc.expectedPacket)
})
}
}
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
return PacketExpectation{
SrcIP: net.ParseIP(srcIP),
DstIP: net.ParseIP(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
UDP: true,
}
}
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
t.Helper()
inactive, err := pcap.NewInactiveHandle(intf)
require.NoError(t, err, "Failed to create inactive pcap handle")
defer inactive.CleanUp()
err = inactive.SetSnapLen(1600)
require.NoError(t, err, "Failed to set snap length on inactive handle")
err = inactive.SetTimeout(time.Second * 10)
require.NoError(t, err, "Failed to set timeout on inactive handle")
err = inactive.SetImmediateMode(true)
require.NoError(t, err, "Failed to set immediate mode on inactive handle")
handle, err := inactive.Activate()
require.NoError(t, err, "Failed to activate pcap handle")
t.Cleanup(handle.Close)
err = handle.SetBPFFilter(filter)
require.NoError(t, err, "Failed to set BPF filter")
return handle
}
func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) {
t.Helper()
if dialer == nil {
dialer = &net.Dialer{}
}
if sourcePort != 0 {
localUDPAddr := &net.UDPAddr{
IP: net.IPv4zero,
Port: sourcePort,
}
switch dialer := dialer.(type) {
case *nbnet.Dialer:
dialer.LocalAddr = localUDPAddr
case *net.Dialer:
dialer.LocalAddr = localUDPAddr
default:
t.Fatal("Unsupported dialer type")
}
}
msg := new(dns.Msg)
msg.Id = dns.Id()
msg.RecursionDesired = true
msg.Question = []dns.Question{
{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
conn, err := dialer.Dial("udp", destination)
require.NoError(t, err, "Failed to dial UDP")
defer conn.Close()
data, err := msg.Pack()
require.NoError(t, err, "Failed to pack DNS message")
_, err = conn.Write(data)
if err != nil {
if strings.Contains(err.Error(), "required key not available") {
t.Logf("Ignoring WireGuard key error: %v", err)
return
}
t.Fatalf("Failed to send DNS query: %v", err)
}
}
func createBPFFilter(destination string) string {
host, port, err := net.SplitHostPort(destination)
if err != nil {
return fmt.Sprintf("udp and dst host %s and dst port %s", host, port)
}
return "udp"
}
func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) {
t.Helper()
ipLayer := packet.Layer(layers.LayerTypeIPv4)
require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet")
ip, ok := ipLayer.(*layers.IPv4)
require.True(t, ok, "Failed to cast to IPv4 layer")
// Convert both source and destination IP addresses to 16-byte representation
expectedSrcIP := exp.SrcIP.To16()
actualSrcIP := ip.SrcIP.To16()
assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch")
expectedDstIP := exp.DstIP.To16()
actualDstIP := ip.DstIP.To16()
assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch")
if exp.UDP {
udpLayer := packet.Layer(layers.LayerTypeUDP)
require.NotNil(t, udpLayer, "Expected UDP layer not found in packet")
udp, ok := udpLayer.(*layers.UDP)
require.True(t, ok, "Failed to cast to UDP layer")
assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch")
assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch")
}
if exp.TCP {
tcpLayer := packet.Layer(layers.LayerTypeTCP)
require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet")
tcp, ok := tcpLayer.(*layers.TCP)
require.True(t, ok, "Failed to cast to TCP layer")
assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch")
assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch")
}
}

View File

@@ -0,0 +1,218 @@
//go:build windows
package systemops
import (
"fmt"
"net"
"net/netip"
"os"
"os/exec"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/peer"
)
type MSFT_NetRoute struct {
DestinationPrefix string
NextHop string
InterfaceIndex int32
InterfaceAlias string
AddressFamily uint16
}
type Route struct {
Destination netip.Prefix
Nexthop netip.Addr
Interface *net.Interface
}
type MSFT_NetNeighbor struct {
IPAddress string
LinkLayerAddress string
State uint8
AddressFamily uint16
InterfaceIndex uint32
InterfaceAlias string
}
type Neighbor struct {
IPAddress netip.Addr
LinkLayerAddress string
State uint8
AddressFamily uint16
InterfaceIndex uint32
InterfaceAlias string
}
var prefixList []netip.Prefix
var lastUpdate time.Time
var mux = sync.Mutex{}
func (r *SysOps) SetupRouting(initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
return r.setupRefCounter(initAddresses)
}
func (r *SysOps) CleanupRouting() error {
return r.cleanupRefCounter()
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
if nexthop.IP.Zone() != "" && nexthop.Intf == nil {
zone, err := strconv.Atoi(nexthop.IP.Zone())
if err != nil {
return fmt.Errorf("invalid zone: %w", err)
}
nexthop.Intf = &net.Interface{Index: zone}
nexthop.IP.WithZone("")
}
return addRouteCmd(prefix, nexthop)
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"delete", prefix.String()}
if nexthop.IP.IsValid() {
nexthop.IP.WithZone("")
args = append(args, nexthop.IP.Unmap().String())
}
routeCmd := uspfilter.GetSystem32Command("route")
out, err := exec.Command(routeCmd, args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("remove route: %w", err)
}
return nil
}
func getRoutesFromTable() ([]netip.Prefix, error) {
mux.Lock()
defer mux.Unlock()
// If many routes are added at the same time this might block for a long time (seconds to minutes), so we cache the result
if !isCacheDisabled() && time.Since(lastUpdate) < 2*time.Second {
return prefixList, nil
}
routes, err := GetRoutes()
if err != nil {
return nil, fmt.Errorf("get routes: %w", err)
}
prefixList = nil
for _, route := range routes {
prefixList = append(prefixList, route.Destination)
}
lastUpdate = time.Now()
return prefixList, nil
}
func GetRoutes() ([]Route, error) {
var entries []MSFT_NetRoute
query := `SELECT DestinationPrefix, Nexthop, InterfaceIndex, InterfaceAlias, AddressFamily FROM MSFT_NetRoute`
if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
return nil, fmt.Errorf("get routes: %w", err)
}
var routes []Route
for _, entry := range entries {
dest, err := netip.ParsePrefix(entry.DestinationPrefix)
if err != nil {
log.Warnf("Unable to parse route destination %s: %v", entry.DestinationPrefix, err)
continue
}
nexthop, err := netip.ParseAddr(entry.NextHop)
if err != nil {
log.Warnf("Unable to parse route next hop %s: %v", entry.NextHop, err)
continue
}
var intf *net.Interface
if entry.InterfaceIndex != 0 {
intf = &net.Interface{
Index: int(entry.InterfaceIndex),
Name: entry.InterfaceAlias,
}
}
routes = append(routes, Route{
Destination: dest,
Nexthop: nexthop,
Interface: intf,
})
}
return routes, nil
}
func GetNeighbors() ([]Neighbor, error) {
var entries []MSFT_NetNeighbor
query := `SELECT IPAddress, LinkLayerAddress, State, AddressFamily, InterfaceIndex, InterfaceAlias FROM MSFT_NetNeighbor`
if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
return nil, fmt.Errorf("failed to query MSFT_NetNeighbor: %w", err)
}
var neighbors []Neighbor
for _, entry := range entries {
addr, err := netip.ParseAddr(entry.IPAddress)
if err != nil {
log.Warnf("Unable to parse neighbor IP address %s: %v", entry.IPAddress, err)
continue
}
neighbors = append(neighbors, Neighbor{
IPAddress: addr,
LinkLayerAddress: entry.LinkLayerAddress,
State: entry.State,
AddressFamily: entry.AddressFamily,
InterfaceIndex: entry.InterfaceIndex,
InterfaceAlias: entry.InterfaceAlias,
})
}
return neighbors, nil
}
func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"add", prefix.String()}
if nexthop.IP.IsValid() {
args = append(args, nexthop.IP.Unmap().String())
} else {
addr := "0.0.0.0"
if prefix.Addr().Is6() {
addr = "::"
}
args = append(args, addr)
}
if nexthop.Intf != nil {
args = append(args, "if", strconv.Itoa(nexthop.Intf.Index))
}
routeCmd := uspfilter.GetSystem32Command("route")
out, err := exec.Command(routeCmd, args...).CombinedOutput()
log.Tracef("route %s: %s", strings.Join(args, " "), out)
if err != nil {
return fmt.Errorf("route add: %w", err)
}
return nil
}
func isCacheDisabled() bool {
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
}

View File

@@ -0,0 +1,289 @@
package systemops
import (
"context"
"encoding/json"
"fmt"
"net"
"os/exec"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/util/net"
)
var expectedExtInt = "Ethernet1"
type RouteInfo struct {
NextHop string `json:"nexthop"`
InterfaceAlias string `json:"interfacealias"`
RouteMetric int `json:"routemetric"`
}
type FindNetRouteOutput struct {
IPAddress string `json:"IPAddress"`
InterfaceIndex int `json:"InterfaceIndex"`
InterfaceAlias string `json:"InterfaceAlias"`
AddressFamily int `json:"AddressFamily"`
NextHop string `json:"Nexthop"`
DestinationPrefix string `json:"DestinationPrefix"`
}
type testCase struct {
name string
destination string
expectedSourceIP string
expectedDestPrefix string
expectedNextHop string
expectedInterface string
dialer dialer
}
var expectedVPNint = "wgtest0"
var testCases = []testCase{
{
name: "To external host without custom dialer via vpn",
destination: "192.0.2.1:53",
expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "128.0.0.0/1",
expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0",
dialer: &net.Dialer{},
},
{
name: "To external host with custom dialer via physical interface",
destination: "192.0.2.1:53",
expectedDestPrefix: "192.0.2.1/32",
expectedInterface: expectedExtInt,
dialer: nbnet.NewDialer(),
},
{
name: "To duplicate internal route with custom dialer via physical interface",
destination: "10.0.0.2:53",
expectedDestPrefix: "10.0.0.2/32",
expectedInterface: expectedExtInt,
dialer: nbnet.NewDialer(),
},
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
destination: "10.0.0.2:53",
expectedSourceIP: "10.0.0.1",
expectedDestPrefix: "10.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
dialer: &net.Dialer{},
},
{
name: "To unique vpn route with custom dialer via physical interface",
destination: "172.16.0.2:53",
expectedDestPrefix: "172.16.0.2/32",
expectedInterface: expectedExtInt,
dialer: nbnet.NewDialer(),
},
{
name: "To unique vpn route without custom dialer via vpn",
destination: "172.16.0.2:53",
expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "172.16.0.0/12",
expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0",
dialer: &net.Dialer{},
},
{
name: "To more specific route without custom dialer via vpn interface",
destination: "10.10.0.2:53",
expectedSourceIP: "100.64.0.1",
expectedDestPrefix: "10.10.0.0/24",
expectedNextHop: "0.0.0.0",
expectedInterface: "wgtest0",
dialer: &net.Dialer{},
},
{
name: "To more specific route (local) without custom dialer via physical interface",
destination: "127.0.10.2:53",
expectedSourceIP: "10.0.0.1",
expectedDestPrefix: "127.0.0.0/8",
expectedNextHop: "0.0.0.0",
expectedInterface: "Loopback Pseudo-Interface 1",
dialer: &net.Dialer{},
},
}
func TestRouting(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
setupTestEnv(t)
route, err := fetchOriginalGateway()
require.NoError(t, err, "Failed to fetch original gateway")
ip, err := fetchInterfaceIP(route.InterfaceAlias)
require.NoError(t, err, "Failed to fetch interface IP")
output := testRoute(t, tc.destination, tc.dialer)
if tc.expectedInterface == expectedExtInt {
verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias)
} else {
verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface)
}
})
}
}
// fetchInterfaceIP fetches the IPv4 address of the specified interface.
func fetchInterfaceIP(interfaceAlias string) (string, error) {
script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias)
out, err := exec.Command("powershell", "-Command", script).Output()
if err != nil {
return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err)
}
ip := strings.TrimSpace(string(out))
return ip, nil
}
func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
conn, err := dialer.DialContext(ctx, "udp", destination)
require.NoError(t, err, "Failed to dial destination")
defer func() {
err := conn.Close()
assert.NoError(t, err, "Failed to close connection")
}()
host, _, err := net.SplitHostPort(destination)
require.NoError(t, err)
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, Nexthop, DestinationPrefix | ConvertTo-Json`, host)
out, err := exec.Command("powershell", "-Command", script).Output()
require.NoError(t, err, "Failed to execute Find-NetRoute")
var outputs []FindNetRouteOutput
err = json.Unmarshal(out, &outputs)
require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute")
require.Greater(t, len(outputs), 0, "No route found for destination")
combinedOutput := combineOutputs(outputs)
return combinedOutput
}
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
t.Helper()
ip, ipNet, err := net.ParseCIDR(ipAddressCIDR)
require.NoError(t, err)
subnetMaskSize, _ := ipNet.Mask.Size()
script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize)
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
require.NoError(t, err, "Failed to assign IP address to loopback adapter")
// Wait for the IP address to be applied
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err = waitForIPAddress(ctx, interfaceName, ip.String())
require.NoError(t, err, "IP address not applied within timeout")
t.Cleanup(func() {
script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String())
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
require.NoError(t, err, "Failed to remove IP address from loopback adapter")
})
return interfaceName
}
func fetchOriginalGateway() (*RouteInfo, error) {
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object Nexthop, RouteMetric, InterfaceAlias | ConvertTo-Json")
output, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err)
}
var routeInfo RouteInfo
err = json.Unmarshal(output, &routeInfo)
if err != nil {
return nil, fmt.Errorf("failed to parse JSON output: %w", err)
}
return &routeInfo, nil
}
func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) {
t.Helper()
assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch")
assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch")
assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch")
assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch")
}
func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput()
if err != nil {
return err
}
ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n")
for _, ip := range ipAddresses {
if strings.TrimSpace(ip) == expectedIPAddress {
return nil
}
}
}
}
}
func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
var combined FindNetRouteOutput
for _, output := range outputs {
if output.IPAddress != "" {
combined.IPAddress = output.IPAddress
}
if output.InterfaceIndex != 0 {
combined.InterfaceIndex = output.InterfaceIndex
}
if output.InterfaceAlias != "" {
combined.InterfaceAlias = output.InterfaceAlias
}
if output.AddressFamily != 0 {
combined.AddressFamily = output.AddressFamily
}
if output.NextHop != "" {
combined.NextHop = output.NextHop
}
if output.DestinationPrefix != "" {
combined.DestinationPrefix = output.DestinationPrefix
}
}
return &combined
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8")
}