[client] Use native windows sock opts to avoid routing loops (#4314)

- Move `util/grpc` and `util/net` to `client` so `internal` packages can be accessed
 - Add methods to return the next best interface after the NetBird interface.
- Use `IP_UNICAST_IF` sock opt to force the outgoing interface for the NetBird `net.Dialer` and `net.ListenerConfig` to avoid routing loops. The interface is picked by the new route lookup method.
- Some refactoring to avoid import cycles
- Old behavior is available through `NB_USE_LEGACY_ROUTING=true` env var
This commit is contained in:
Viktor Liu
2025-09-20 09:31:04 +02:00
committed by GitHub
parent 90577682e4
commit 55126f990c
77 changed files with 1180 additions and 606 deletions

View File

@@ -19,7 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/client/net"
)
// ConnectionListener export internal Listener for mobile

View File

@@ -27,7 +27,7 @@ var downCmd = &cobra.Command{
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr)

View File

@@ -12,7 +12,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (

View File

@@ -19,7 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// constants needed to manage and create iptable rules

View File

@@ -14,7 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
func isIptablesSupported() bool {

View File

@@ -16,7 +16,7 @@ import (
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (

View File

@@ -22,7 +22,7 @@ import (
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (

95
client/grpc/dialer.go Normal file
View File

@@ -0,0 +1,95 @@
package grpc
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os/user"
"runtime"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/util/embeddedroots"
)
func WithCustomDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
})
}
// grpcDialBackoff is the backoff mechanism for the grpc calls
func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff()
b.MaxElapsedTime = 10 * time.Second
b.Clock = backoff.SystemClock
return backoff.WithContext(b, ctx)
}
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
certPool = embeddedroots.Get()
}
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
RootCAs: certPool,
}))
}
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(
connCtx,
addr,
transportOption,
WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}),
)
if err != nil {
log.Printf("DialContext error: %v", err)
return nil, err
}
return conn, nil
}

View File

@@ -3,7 +3,7 @@ package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)

View File

@@ -17,7 +17,7 @@ import (
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
type RecvMessage struct {

View File

@@ -17,8 +17,8 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/monotime"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
@@ -409,7 +409,7 @@ func toBytes(s string) (int64, error) {
}
func getFwmark() int {
if nbnet.AdvancedRouting() {
if nbnet.AdvancedRouting() && runtime.GOOS == "linux" {
return nbnet.ControlPlaneMark
}
return 0

View File

@@ -15,8 +15,8 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/sharedsock"
nbnet "github.com/netbirdio/netbird/util/net"
)
type TunKernelDevice struct {
@@ -101,13 +101,8 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, err
}
var udpConn net.PacketConn = rawSock
if !nbnet.AdvancedRouting() {
udpConn = nbnet.WrapPacketConn(rawSock)
}
bindParams := udpmux.UniversalUDPMuxParams{
UDPConn: udpConn,
UDPConn: nbnet.WrapPacketConn(rawSock),
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,

View File

@@ -12,7 +12,7 @@ import (
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
type TunNetstackDevice struct {

View File

@@ -3,7 +3,7 @@
package udpmux
import (
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {

View File

@@ -20,7 +20,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (

View File

@@ -34,7 +34,7 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/version"
)

View File

@@ -10,7 +10,7 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
type ServiceViaMemory struct {

View File

@@ -10,7 +10,7 @@ import (
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
type upstreamResolver struct {

View File

@@ -446,6 +446,8 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("up wg interface: %w", err)
}
// if inbound conns are blocked there is no need to create the ACL manager
if e.firewall != nil && !e.config.BlockInbound {
e.acl = acl.NewDefaultManager(e.firewall)

View File

@@ -14,7 +14,7 @@ import (
"github.com/ti-mo/netfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const defaultChannelSize = 100

View File

@@ -12,7 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// ProbeResult holds the info about the result of a relay probe request

View File

@@ -36,9 +36,9 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
)
@@ -108,6 +108,10 @@ func NewManager(config ManagerConfig) *DefaultManager {
notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
if runtime.GOOS == "windows" && config.WGInterface != nil {
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
}
dm := &DefaultManager{
ctx: mCTX,
stop: cancel,
@@ -208,7 +212,7 @@ func (m *DefaultManager) Init() error {
return nil
}
if err := m.sysOps.CleanupRouting(nil); err != nil {
if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil {
log.Warnf("Failed cleaning up routing: %v", err)
}
@@ -219,7 +223,7 @@ func (m *DefaultManager) Init() error {
ips := resolveURLsToIPs(initialAddresses)
if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil {
return fmt.Errorf("setup routing: %w", err)
}
@@ -285,11 +289,15 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
}
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
if err := m.sysOps.CleanupRouting(stateManager); err != nil {
if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
}
if runtime.GOOS == "windows" {
nbnet.SetVPNInterfaceName("")
}
}
m.mux.Lock()

View File

@@ -12,11 +12,11 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
return nil
}
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
return nil
}

View File

@@ -3,7 +3,6 @@
package systemops
import (
"context"
"errors"
"fmt"
"net"
@@ -22,7 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/client/net/hooks"
)
const localSubnetsCacheTTL = 15 * time.Minute
@@ -96,9 +95,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
return nil
}
// TODO: Remove hooks selectively
nbnet.RemoveDialerHooks()
nbnet.RemoveListenerHooks()
hooks.RemoveWriteHooks()
hooks.RemoveCloseHooks()
hooks.RemoveAddressRemoveHooks()
if err := r.refCounter.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
@@ -290,12 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
}
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error {
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
@@ -304,7 +298,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
afterHook := func(connID hooks.ConnectionID) error {
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
@@ -317,36 +311,20 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
var merr *multierror.Error
for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err))
continue
}
if err := beforeHook("init", prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err))
}
}
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
if ctx.Err() != nil {
return ctx.Err()
}
hooks.AddWriteHook(beforeHook)
hooks.AddCloseHook(afterHook)
var merr *multierror.Error
for _, ip := range resolvedIPs {
merr = multierror.Append(merr, beforeHook(connID, ip.IP))
}
return nberrors.FormatErrorOrNil(merr)
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error {
if _, err := r.refCounter.Decrement(prefix); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/client/net"
)
type dialer interface {
@@ -143,10 +144,11 @@ func TestAddVPNRoute(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil)
err := r.SetupRouting(nil, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
})
intf, err := net.InterfaceByName(wgInterface.Name())
@@ -341,10 +343,11 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil)
err := r.SetupRouting(nil, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
})
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
@@ -484,10 +487,11 @@ func setupTestEnv(t *testing.T) {
})
r := NewSysOps(wgInterface, nil)
err := r.SetupRouting(nil, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
})
index, err := net.InterfaceByName(wgInterface.Name())

View File

@@ -12,14 +12,14 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{})
return nil
}
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
r.mu.Lock()
defer r.mu.Unlock()

View File

@@ -20,7 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// IPRule contains IP rule information for debugging
@@ -94,15 +94,15 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
if !nbnet.AdvancedRouting() {
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) (err error) {
if !advancedRouting {
log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses, stateManager)
}
defer func() {
if err != nil {
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
if cleanErr := r.CleanupRouting(stateManager, advancedRouting); cleanErr != nil {
log.Errorf("Error cleaning up routing: %v", cleanErr)
}
}
@@ -132,8 +132,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process.
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
if !nbnet.AdvancedRouting() {
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
if !advancedRouting {
return r.cleanupRefCounter(stateManager)
}

View File

@@ -20,11 +20,11 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
return r.setupRefCounter(initAddresses, stateManager)
}
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
return r.cleanupRefCounter(stateManager)
}

View File

@@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
type PacketExpectation struct {

View File

@@ -8,6 +8,7 @@ import (
"net/netip"
"os"
"runtime/debug"
"sort"
"strconv"
"sync"
"syscall"
@@ -19,9 +20,16 @@ import (
"golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
)
const InfiniteLifetime = 0xffffffff
func init() {
nbnet.GetBestInterfaceFunc = GetBestInterface
}
const (
InfiniteLifetime = 0xffffffff
)
type RouteUpdateType int
@@ -77,6 +85,14 @@ type MIB_IPFORWARD_TABLE2 struct {
Table [1]MIB_IPFORWARD_ROW2 // Flexible array member
}
// candidateRoute represents a potential route for selection during route lookup
type candidateRoute struct {
interfaceIndex uint32
prefixLength uint8
routeMetric uint32
interfaceMetric int
}
// IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix
type IP_ADDRESS_PREFIX struct {
Prefix SOCKADDR_INET
@@ -177,11 +193,20 @@ const (
RouteDeleted
)
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return nil
}
log.Infof("Using legacy routing setup with ref counters")
return r.setupRefCounter(initAddresses, stateManager)
}
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return nil
}
return r.cleanupRefCounter(stateManager)
}
@@ -635,10 +660,7 @@ func getWindowsRoutingTable() (*MIB_IPFORWARD_TABLE2, error) {
func freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) {
if table != nil {
ret, _, _ := procFreeMibTable.Call(uintptr(unsafe.Pointer(table)))
if ret != 0 {
log.Warnf("FreeMibTable failed with return code: %d", ret)
}
_, _, _ = procFreeMibTable.Call(uintptr(unsafe.Pointer(table)))
}
}
@@ -652,8 +674,7 @@ func parseWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) []DetailedRoute {
entryPtr := basePtr + uintptr(i)*entrySize
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
detailed := buildWindowsDetailedRoute(entry)
if detailed != nil {
if detailed := buildWindowsDetailedRoute(entry); detailed != nil {
detailedRoutes = append(detailedRoutes, *detailed)
}
}
@@ -802,6 +823,46 @@ func addZone(ip netip.Addr, interfaceIndex int) netip.Addr {
return ip
}
// parseCandidatesFromTable extracts all matching candidate routes from the routing table
func parseCandidatesFromTable(table *MIB_IPFORWARD_TABLE2, dest netip.Addr, skipInterfaceIndex int) []candidateRoute {
var candidates []candidateRoute
entrySize := unsafe.Sizeof(MIB_IPFORWARD_ROW2{})
basePtr := uintptr(unsafe.Pointer(&table.Table[0]))
for i := uint32(0); i < table.NumEntries; i++ {
entryPtr := basePtr + uintptr(i)*entrySize
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
if candidate := parseCandidateRoute(entry, dest, skipInterfaceIndex); candidate != nil {
candidates = append(candidates, *candidate)
}
}
return candidates
}
// parseCandidateRoute extracts candidate route information from a MIB_IPFORWARD_ROW2 entry
// Returns nil if the route doesn't match the destination or should be skipped
func parseCandidateRoute(entry *MIB_IPFORWARD_ROW2, dest netip.Addr, skipInterfaceIndex int) *candidateRoute {
if skipInterfaceIndex > 0 && int(entry.InterfaceIndex) == skipInterfaceIndex {
return nil
}
destPrefix := parseIPPrefix(entry.DestinationPrefix, int(entry.InterfaceIndex))
if !destPrefix.IsValid() || !destPrefix.Contains(dest) {
return nil
}
interfaceMetric := getInterfaceMetric(entry.InterfaceIndex, entry.DestinationPrefix.Prefix.sin6_family)
return &candidateRoute{
interfaceIndex: entry.InterfaceIndex,
prefixLength: entry.DestinationPrefix.PrefixLength,
routeMetric: entry.Metric,
interfaceMetric: interfaceMetric,
}
}
// getInterfaceMetric retrieves the interface metric for a given interface and address family
func getInterfaceMetric(interfaceIndex uint32, family int16) int {
if interfaceIndex == 0 {
@@ -821,6 +882,75 @@ func getInterfaceMetric(interfaceIndex uint32, family int16) int {
return int(ipInterfaceRow.Metric)
}
// sortRouteCandidates sorts route candidates by priority: prefix length -> route metric -> interface metric
func sortRouteCandidates(candidates []candidateRoute) {
sort.Slice(candidates, func(i, j int) bool {
if candidates[i].prefixLength != candidates[j].prefixLength {
return candidates[i].prefixLength > candidates[j].prefixLength
}
if candidates[i].routeMetric != candidates[j].routeMetric {
return candidates[i].routeMetric < candidates[j].routeMetric
}
return candidates[i].interfaceMetric < candidates[j].interfaceMetric
})
}
// GetBestInterface finds the best interface for reaching a destination,
// excluding the VPN interface to avoid routing loops.
//
// Route selection priority:
// 1. Longest prefix match (most specific route)
// 2. Lowest route metric
// 3. Lowest interface metric
func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) {
var skipInterfaceIndex int
if vpnIntf != "" {
if iface, err := net.InterfaceByName(vpnIntf); err == nil {
skipInterfaceIndex = iface.Index
} else {
return nil, fmt.Errorf("get VPN interface %s: %w", vpnIntf, err)
}
}
table, err := getWindowsRoutingTable()
if err != nil {
return nil, fmt.Errorf("get routing table: %w", err)
}
defer freeWindowsRoutingTable(table)
candidates := parseCandidatesFromTable(table, dest, skipInterfaceIndex)
if len(candidates) == 0 {
return nil, fmt.Errorf("no route to %s", dest)
}
// Sort routes: prefix length -> route metric -> interface metric
sortRouteCandidates(candidates)
for _, candidate := range candidates {
iface, err := net.InterfaceByIndex(int(candidate.interfaceIndex))
if err != nil {
log.Warnf("failed to get interface by index %d: %v", candidate.interfaceIndex, err)
continue
}
if iface.Flags&net.FlagLoopback != 0 && !dest.IsLoopback() {
continue
}
if iface.Flags&net.FlagUp == 0 {
log.Debugf("interface %s is down, trying next route", iface.Name)
continue
}
log.Debugf("route lookup for %s: selected interface %s (index %d), route metric %d, interface metric %d",
dest, iface.Name, iface.Index, candidate.routeMetric, candidate.interfaceMetric)
return iface, nil
}
return nil, fmt.Errorf("no usable interface found for %s", dest)
}
// formatRouteAge formats the route age in seconds to a human-readable string
func formatRouteAge(ageSeconds uint32) string {
if ageSeconds == 0 {

View File

@@ -15,7 +15,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
var (

View File

@@ -12,18 +12,8 @@ func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) {
if !ok {
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
}
addr = addr.Unmap()
var prefixLength int
switch {
case addr.Is4():
prefixLength = 32
case addr.Is6():
prefixLength = 128
default:
return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr)
}
prefix := netip.PrefixFrom(addr, prefixLength)
prefix := netip.PrefixFrom(addr, addr.BitLen())
return prefix, nil
}

View File

@@ -5,7 +5,7 @@ import (
"github.com/pion/transport/v3"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// Dial connects to the address on the named network.

View File

@@ -6,7 +6,7 @@ import (
"github.com/pion/transport/v3"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// ListenPacket listens for incoming packets on the given network and address.

49
client/net/conn.go Normal file
View File

@@ -0,0 +1,49 @@
//go:build !ios
package net
import (
"io"
"net"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/net/hooks"
)
// Conn wraps a net.Conn to override the Close method
type Conn struct {
net.Conn
ID hooks.ConnectionID
}
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
func (c *Conn) Close() error {
return closeConn(c.ID, c.Conn)
}
// TCPConn wraps net.TCPConn to override its Close method to include hook functionality.
type TCPConn struct {
*net.TCPConn
ID hooks.ConnectionID
}
// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
func (c *TCPConn) Close() error {
return closeConn(c.ID, c.TCPConn)
}
// closeConn is a helper function to close connections and execute close hooks.
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
err := conn.Close()
closeHooks := hooks.GetCloseHooks()
for _, hook := range closeHooks {
if err := hook(id); err != nil {
log.Errorf("Error executing close hook: %v", err)
}
}
return err
}

82
client/net/dial.go Normal file
View File

@@ -0,0 +1,82 @@
//go:build !ios
package net
import (
"fmt"
"net"
"sync"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
)
func DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) {
if CustomRoutingDisabled() {
return net.DialUDP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}
switch c := conn.(type) {
case *net.UDPConn:
// Advanced routing: plain connection
return c, nil
case *Conn:
// Legacy routing: wrapped connection preserves close hooks
udpConn, ok := c.Conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got %T", c.Conn)
}
return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil
}
if err := conn.Close(); err != nil {
log.Errorf("failed to close connection: %v", err)
}
return nil, fmt.Errorf("unexpected connection type: %T", conn)
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) {
if CustomRoutingDisabled() {
return net.DialTCP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}
switch c := conn.(type) {
case *net.TCPConn:
// Advanced routing: plain connection
return c, nil
case *Conn:
// Legacy routing: wrapped connection preserves close hooks
tcpConn, ok := c.Conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got %T", c.Conn)
}
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
}
if err := conn.Close(); err != nil {
log.Errorf("failed to close connection: %v", err)
}
return nil, fmt.Errorf("unexpected connection type: %T", conn)
}

13
client/net/dial_ios.go Normal file
View File

@@ -0,0 +1,13 @@
package net
import (
"net"
)
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
return net.DialUDP(network, laddr, raddr)
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
return net.DialTCP(network, laddr, raddr)
}

20
client/net/dialer.go Normal file
View File

@@ -0,0 +1,20 @@
package net
import (
"net"
)
// Dialer extends the standard net.Dialer with the ability to execute hooks before
// and after connections. This can be used to bypass the VPN for connections using this dialer.
type Dialer struct {
*net.Dialer
}
// NewDialer returns a customized net.Dialer with overridden Control method
func NewDialer() *Dialer {
dialer := &Dialer{
Dialer: &net.Dialer{},
}
dialer.init()
return dialer
}

87
client/net/dialer_dial.go Normal file
View File

@@ -0,0 +1,87 @@
//go:build !ios
package net
import (
"context"
"fmt"
"net"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/net/hooks"
)
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
log.Debugf("Dialing %s %s", network, address)
if CustomRoutingDisabled() || AdvancedRouting() {
return d.Dialer.DialContext(ctx, network, address)
}
connID := hooks.GenerateConnID()
if err := callDialerHooks(ctx, connID, address, d.Resolver); err != nil {
log.Errorf("Failed to call dialer hooks: %v", err)
}
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
}
// Wrap the connection in Conn to handle Close with hooks
return &Conn{Conn: conn, ID: connID}, nil
}
// Dial wraps the net.Dialer's Dial method to use the custom connection
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}
func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address string, customResolver *net.Resolver) error {
if ctx.Err() != nil {
return ctx.Err()
}
writeHooks := hooks.GetWriteHooks()
if len(writeHooks) == 0 {
return nil
}
host, _, err := net.SplitHostPort(address)
if err != nil {
return fmt.Errorf("split host and port: %w", err)
}
resolver := customResolver
if resolver == nil {
resolver = net.DefaultResolver
}
ips, err := resolver.LookupIPAddr(ctx, host)
if err != nil {
return fmt.Errorf("failed to resolve address %s: %w", address, err)
}
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
var merr *multierror.Error
for _, ip := range ips {
prefix, err := util.GetPrefixFromIP(ip.IP)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("convert IP %s to prefix: %w", ip.IP, err))
continue
}
for _, hook := range writeHooks {
if err := hook(connID, prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("executing dial hook for IP %s: %w", ip.IP, err))
}
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -0,0 +1,5 @@
package net
func (d *Dialer) init() {
d.Dialer.Control = ControlProtectSocket
}

View File

@@ -0,0 +1,7 @@
//go:build !linux && !windows
package net
func (d *Dialer) init() {
// implemented on Linux, Android, and Windows only
}

View File

@@ -0,0 +1,12 @@
//go:build !android
package net
import "syscall"
// init configures the net.Dialer Control function to set the fwmark on the socket
func (d *Dialer) init() {
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
return setRawSocketMark(c)
}
}

View File

@@ -0,0 +1,5 @@
package net
func (d *Dialer) init() {
d.Dialer.Control = applyUnicastIFToSocket
}

35
client/net/env.go Normal file
View File

@@ -0,0 +1,35 @@
package net
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
)
const (
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
envUseLegacyRouting = "NB_USE_LEGACY_ROUTING"
)
// CustomRoutingDisabled returns true if custom routing is disabled.
// This will fall back to the operation mode before the exit node functionality was implemented.
// In particular exclusion routes won't be set up and all dialers and listeners will use net.Dial and net.Listen, respectively.
func CustomRoutingDisabled() bool {
if netstack.IsEnabled() {
return true
}
var customRoutingDisabled bool
if val := os.Getenv(envDisableCustomRouting); val != "" {
var err error
customRoutingDisabled, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envDisableCustomRouting, err)
}
}
return customRoutingDisabled
}

24
client/net/env_android.go Normal file
View File

@@ -0,0 +1,24 @@
//go:build android
package net
// Init initializes the network environment for Android
func Init() {
// No initialization needed on Android
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on Android since we cannot handle routes dynamically.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on Android
func SetVPNInterfaceName(name string) {
// No-op on Android - not needed for Android VPN service
}
// GetVPNInterfaceName returns empty string on Android
func GetVPNInterfaceName() string {
return ""
}

23
client/net/env_generic.go Normal file
View File

@@ -0,0 +1,23 @@
//go:build !linux && !windows && !android
package net
// Init initializes the network environment (no-op on non-Linux/Windows platforms)
func Init() {
// No-op on non-Linux/Windows platforms
}
// AdvancedRouting returns false on non-Linux/Windows platforms
func AdvancedRouting() bool {
return false
}
// SetVPNInterfaceName is a no-op on non-Windows platforms
func SetVPNInterfaceName(name string) {
// No-op on non-Windows platforms
}
// GetVPNInterfaceName returns empty string on non-Windows platforms
func GetVPNInterfaceName() string {
return ""
}

141
client/net/env_linux.go Normal file
View File

@@ -0,0 +1,141 @@
//go:build linux && !android
package net
import (
"errors"
"os"
"strconv"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/iface/netstack"
)
const (
// these have the same effect, skip socket env supported for backward compatibility
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
)
var advancedRoutingSupported bool
func Init() {
advancedRoutingSupported = checkAdvancedRoutingSupport()
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes
func AdvancedRouting() bool {
return advancedRoutingSupported
}
func checkAdvancedRoutingSupport() bool {
var err error
var legacyRouting bool
if val := os.Getenv(envUseLegacyRouting); val != "" {
legacyRouting, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
}
}
var skipSocketMark bool
if val := os.Getenv(envSkipSocketMark); val != "" {
skipSocketMark, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envSkipSocketMark, err)
}
}
// requested to disable advanced routing
if legacyRouting || skipSocketMark ||
// envCustomRoutingDisabled disables the custom dialers.
// There is no point in using advanced routing without those, as they set up fwmarks on the sockets.
CustomRoutingDisabled() ||
// netstack mode doesn't need routing at all
netstack.IsEnabled() {
log.Info("advanced routing has been requested to be disabled")
return false
}
if !CheckFwmarkSupport() || !CheckRuleOperationsSupport() {
log.Warn("system doesn't support required routing features, falling back to legacy routing")
return false
}
log.Info("system supports advanced routing")
return true
}
func CheckFwmarkSupport() bool {
// temporarily enable advanced routing to check if fwmarks are supported
old := advancedRoutingSupported
advancedRoutingSupported = true
defer func() {
advancedRoutingSupported = old
}()
dialer := NewDialer()
dialer.Timeout = 100 * time.Millisecond
conn, err := dialer.Dial("udp", "127.0.0.1:9")
if err != nil {
log.Warnf("failed to dial with fwmark: %v", err)
return false
}
defer func() {
if err := conn.Close(); err != nil {
log.Warnf("failed to close connection: %v", err)
}
}()
if err := conn.SetWriteDeadline(time.Now().Add(time.Millisecond * 100)); err != nil {
log.Warnf("failed to set write deadline: %v", err)
return false
}
if _, err := conn.Write([]byte("")); err != nil {
log.Warnf("failed to write to fwmark connection: %v", err)
return false
}
return true
}
func CheckRuleOperationsSupport() bool {
rule := netlink.NewRule()
// low precedence, semi-random
rule.Priority = 32321
rule.Table = syscall.RT_TABLE_MAIN
rule.Family = netlink.FAMILY_V4
if err := netlink.RuleAdd(rule); err != nil {
if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warn("IP rule operations are not supported")
return false
}
log.Warnf("failed to test rule support: %v", err)
return false
}
if err := netlink.RuleDel(rule); err != nil {
log.Warnf("failed to delete test rule: %v", err)
}
return true
}
// SetVPNInterfaceName is a no-op on Linux
func SetVPNInterfaceName(name string) {
// No-op on Linux - not needed for fwmark-based routing
}
// GetVPNInterfaceName returns empty string on Linux
func GetVPNInterfaceName() string {
return ""
}

67
client/net/env_windows.go Normal file
View File

@@ -0,0 +1,67 @@
//go:build windows
package net
import (
"os"
"strconv"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
)
var (
vpnInterfaceName string
vpnInitMutex sync.RWMutex
advancedRoutingSupported bool
)
func Init() {
advancedRoutingSupported = checkAdvancedRoutingSupport()
}
func checkAdvancedRoutingSupport() bool {
var err error
var legacyRouting bool
if val := os.Getenv(envUseLegacyRouting); val != "" {
legacyRouting, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
}
}
if legacyRouting || netstack.IsEnabled() {
log.Info("advanced routing has been requested to be disabled")
return false
}
log.Info("system supports advanced routing")
return true
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes
func AdvancedRouting() bool {
return advancedRoutingSupported
}
// GetVPNInterfaceName returns the stored VPN interface name
func GetVPNInterfaceName() string {
vpnInitMutex.RLock()
defer vpnInitMutex.RUnlock()
return vpnInterfaceName
}
// SetVPNInterfaceName sets the VPN interface name for lazy initialization
func SetVPNInterfaceName(name string) {
vpnInitMutex.Lock()
defer vpnInitMutex.Unlock()
vpnInterfaceName = name
if name != "" {
log.Infof("VPN interface name set to %s for route exclusion", name)
}
}

93
client/net/hooks/hooks.go Normal file
View File

@@ -0,0 +1,93 @@
package hooks
import (
"net/netip"
"slices"
"sync"
"github.com/google/uuid"
)
// ConnectionID provides a globally unique identifier for network connections.
// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
type ConnectionID string
// GenerateConnID generates a unique identifier for each connection.
func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString())
}
type WriteHookFunc func(connID ConnectionID, prefix netip.Prefix) error
type CloseHookFunc func(connID ConnectionID) error
type AddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error
var (
hooksMutex sync.RWMutex
writeHooks []WriteHookFunc
closeHooks []CloseHookFunc
addressRemoveHooks []AddressRemoveHookFunc
)
// AddWriteHook allows adding a new hook to be executed before writing/dialing.
func AddWriteHook(hook WriteHookFunc) {
hooksMutex.Lock()
defer hooksMutex.Unlock()
writeHooks = append(writeHooks, hook)
}
// AddCloseHook allows adding a new hook to be executed on connection close.
func AddCloseHook(hook CloseHookFunc) {
hooksMutex.Lock()
defer hooksMutex.Unlock()
closeHooks = append(closeHooks, hook)
}
// RemoveWriteHooks removes all write hooks.
func RemoveWriteHooks() {
hooksMutex.Lock()
defer hooksMutex.Unlock()
writeHooks = nil
}
// RemoveCloseHooks removes all close hooks.
func RemoveCloseHooks() {
hooksMutex.Lock()
defer hooksMutex.Unlock()
closeHooks = nil
}
// AddAddressRemoveHook allows adding a new hook to be executed when an address is removed.
func AddAddressRemoveHook(hook AddressRemoveHookFunc) {
hooksMutex.Lock()
defer hooksMutex.Unlock()
addressRemoveHooks = append(addressRemoveHooks, hook)
}
// RemoveAddressRemoveHooks removes all listener address hooks.
func RemoveAddressRemoveHooks() {
hooksMutex.Lock()
defer hooksMutex.Unlock()
addressRemoveHooks = nil
}
// GetWriteHooks returns a copy of the current write hooks.
func GetWriteHooks() []WriteHookFunc {
hooksMutex.RLock()
defer hooksMutex.RUnlock()
return slices.Clone(writeHooks)
}
// GetCloseHooks returns a copy of the current close hooks.
func GetCloseHooks() []CloseHookFunc {
hooksMutex.RLock()
defer hooksMutex.RUnlock()
return slices.Clone(closeHooks)
}
// GetAddressRemoveHooks returns a copy of the current listener address remove hooks.
func GetAddressRemoveHooks() []AddressRemoveHookFunc {
hooksMutex.RLock()
defer hooksMutex.RUnlock()
return slices.Clone(addressRemoveHooks)
}

47
client/net/listen.go Normal file
View File

@@ -0,0 +1,47 @@
//go:build !ios
package net
import (
"context"
"fmt"
"net"
"sync"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
)
// ListenUDP listens on the network address and returns a transport.UDPConn
// which includes support for write and close hooks.
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
if CustomRoutingDisabled() {
return net.ListenUDP(network, laddr)
}
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listen UDP: %w", err)
}
switch c := conn.(type) {
case *net.UDPConn:
// Advanced routing: plain connection
return c, nil
case *PacketConn:
// Legacy routing: wrapped connection for hooks
udpConn, ok := c.PacketConn.(*net.UDPConn)
if !ok {
if err := c.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDPConn, got %T", c.PacketConn)
}
return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil
}
if err := conn.Close(); err != nil {
log.Errorf("failed to close connection: %v", err)
}
return nil, fmt.Errorf("unexpected connection type: %T", conn)
}

11
client/net/listen_ios.go Normal file
View File

@@ -0,0 +1,11 @@
//go:build ios
package net
import (
"net"
)
func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
return net.ListenUDP(network, laddr)
}

19
client/net/listener.go Normal file
View File

@@ -0,0 +1,19 @@
package net
import (
"net"
)
// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before
// responding via the socket and after closing. This can be used to bypass the VPN for listeners.
type ListenerConfig struct {
net.ListenConfig
}
// NewListener creates a new ListenerConfig instance.
func NewListener() *ListenerConfig {
listener := &ListenerConfig{}
listener.init()
return listener
}

View File

@@ -0,0 +1,6 @@
package net
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
func (l *ListenerConfig) init() {
l.ListenConfig.Control = ControlProtectSocket
}

View File

@@ -0,0 +1,7 @@
//go:build !linux && !windows
package net
func (l *ListenerConfig) init() {
// implemented on Linux, Android, and Windows only
}

View File

@@ -0,0 +1,14 @@
//go:build !android
package net
import (
"syscall"
)
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
func (l *ListenerConfig) init() {
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
return setRawSocketMark(c)
}
}

View File

@@ -0,0 +1,8 @@
package net
func (l *ListenerConfig) init() {
// TODO: this will select a single source interface, but for UDP we can have various source interfaces and IP addresses.
// For now we stick to the one that matches the request IP address, which can be the unspecified IP. In this case
// the interface will be selected that serves the default route.
l.ListenConfig.Control = applyUnicastIFToSocket
}

View File

@@ -0,0 +1,153 @@
//go:build !ios
package net
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/net/hooks"
)
// ListenPacket listens on the network address and returns a PacketConn
// which includes support for write hooks.
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
if CustomRoutingDisabled() || AdvancedRouting() {
return l.ListenConfig.ListenPacket(ctx, network, address)
}
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("listen packet: %w", err)
}
connID := hooks.GenerateConnID()
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
}
// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality.
type PacketConn struct {
net.PacketConn
ID hooks.ConnectionID
seenAddrs *sync.Map
}
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil {
log.Errorf("Failed to call write hooks: %v", err)
}
return c.PacketConn.WriteTo(b, addr)
}
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
func (c *PacketConn) Close() error {
defer c.seenAddrs.Clear()
return closeConn(c.ID, c.PacketConn)
}
// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality.
type UDPConn struct {
*net.UDPConn
ID hooks.ConnectionID
seenAddrs *sync.Map
}
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil {
log.Errorf("Failed to call write hooks: %v", err)
}
return c.UDPConn.WriteTo(b, addr)
}
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
func (c *UDPConn) Close() error {
defer c.seenAddrs.Clear()
return closeConn(c.ID, c.UDPConn)
}
// RemoveAddress removes an address from the seen cache and triggers removal hooks.
func (c *PacketConn) RemoveAddress(addr string) {
if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists {
return
}
ipStr, _, err := net.SplitHostPort(addr)
if err != nil {
log.Errorf("Error splitting IP address and port: %v", err)
return
}
ipAddr, err := netip.ParseAddr(ipStr)
if err != nil {
log.Errorf("Error parsing IP address %s: %v", ipStr, err)
return
}
prefix := netip.PrefixFrom(ipAddr.Unmap(), ipAddr.BitLen())
addressRemoveHooks := hooks.GetAddressRemoveHooks()
if len(addressRemoveHooks) == 0 {
return
}
for _, hook := range addressRemoveHooks {
if err := hook(c.ID, prefix); err != nil {
log.Errorf("Error executing listener address remove hook: %v", err)
}
}
}
// WrapPacketConn wraps an existing net.PacketConn with nbnet hook functionality
func WrapPacketConn(conn net.PacketConn) net.PacketConn {
if AdvancedRouting() {
// hooks not required for advanced routing
return conn
}
return &PacketConn{
PacketConn: conn,
ID: hooks.GenerateConnID(),
seenAddrs: &sync.Map{},
}
}
func callWriteHooks(id hooks.ConnectionID, seenAddrs *sync.Map, addr net.Addr) error {
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); loaded {
return nil
}
writeHooks := hooks.GetWriteHooks()
if len(writeHooks) == 0 {
return nil
}
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("expected *net.UDPAddr for packet connection, got %T", addr)
}
prefix, err := util.GetPrefixFromIP(udpAddr.IP)
if err != nil {
return fmt.Errorf("convert UDP IP %s to prefix: %w", udpAddr.IP, err)
}
log.Debugf("Listener resolved IP for %s: %s", addr, prefix)
var merr *multierror.Error
for _, hook := range writeHooks {
if err := hook(id, prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("execute write hook: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -0,0 +1,10 @@
package net
import (
"net"
)
// WrapPacketConn on iOS just returns the original connection since iOS handles its own networking
func WrapPacketConn(conn *net.UDPConn) *net.UDPConn {
return conn
}

69
client/net/net.go Normal file
View File

@@ -0,0 +1,69 @@
package net
import (
"fmt"
"math/big"
"net"
"net/netip"
)
const (
// ControlPlaneMark is the fwmark value used to mark packets that should not be routed through the NetBird interface to
// avoid routing loops.
// This includes all control plane traffic (mgmt, signal, flows), relay, ICE/stun/turn and everything that is emitted by the wireguard socket.
// It doesn't collide with the other marks, as the others are used for data plane traffic only.
ControlPlaneMark = 0x1BD00
// Data plane marks (0x1BD10 - 0x1BDFF)
// DataPlaneMarkLower is the lowest value for the data plane range
DataPlaneMarkLower = 0x1BD10
// DataPlaneMarkUpper is the highest value for the data plane range
DataPlaneMarkUpper = 0x1BDFF
// DataPlaneMarkIn is the mark for inbound data plane traffic.
DataPlaneMarkIn = 0x1BD10
// DataPlaneMarkOut is the mark for outbound data plane traffic.
DataPlaneMarkOut = 0x1BD11
// PreroutingFwmarkRedirected is applied to packets that are were redirected (input -> forward, e.g. by Docker or Podman) for special handling.
PreroutingFwmarkRedirected = 0x1BD20
// PreroutingFwmarkMasquerade is applied to packets that arrive from the NetBird interface and should be masqueraded.
PreroutingFwmarkMasquerade = 0x1BD21
// PreroutingFwmarkMasqueradeReturn is applied to packets that will leave through the NetBird interface and should be masqueraded.
PreroutingFwmarkMasqueradeReturn = 0x1BD22
)
// IsDataPlaneMark determines if a fwmark is in the data plane range (0x1BD10-0x1BDFF)
func IsDataPlaneMark(fwmark uint32) bool {
return fwmark >= DataPlaneMarkLower && fwmark <= DataPlaneMarkUpper
}
func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) {
var endIP net.IP
addr := network.Addr().AsSlice()
mask := net.CIDRMask(network.Bits(), len(addr)*8)
for i := 0; i < len(addr); i++ {
endIP = append(endIP, addr[i]|^mask[i])
}
// convert to big.Int
endInt := big.NewInt(0)
endInt.SetBytes(endIP)
// subtract fromEnd from the last ip
fromEndBig := big.NewInt(int64(fromEnd))
resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig)
ip, ok := netip.AddrFromSlice(resultInt.Bytes())
if !ok {
return netip.Addr{}, fmt.Errorf("invalid IP address from network %s", network)
}
return ip.Unmap(), nil
}

55
client/net/net_linux.go Normal file
View File

@@ -0,0 +1,55 @@
//go:build !android
package net
import (
"fmt"
"syscall"
)
// SetSocketMark sets the SO_MARK option on the given socket connection
func SetSocketMark(conn syscall.Conn) error {
if !AdvancedRouting() {
return nil
}
sysconn, err := conn.SyscallConn()
if err != nil {
return fmt.Errorf("get raw conn: %w", err)
}
return setRawSocketMark(sysconn)
}
// SetSocketOpt sets the SO_MARK option on the given file descriptor
func SetSocketOpt(fd int) error {
if !AdvancedRouting() {
return nil
}
return setSocketOptInt(fd)
}
func setRawSocketMark(conn syscall.RawConn) error {
var setErr error
err := conn.Control(func(fd uintptr) {
if !AdvancedRouting() {
return
}
setErr = setSocketOptInt(int(fd))
})
if err != nil {
return fmt.Errorf("control: %w", err)
}
if setErr != nil {
return fmt.Errorf("set SO_MARK: %w", setErr)
}
return nil
}
func setSocketOptInt(fd int) error {
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, ControlPlaneMark)
}

94
client/net/net_test.go Normal file
View File

@@ -0,0 +1,94 @@
package net
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
name string
network string
fromEnd int
expected string
expectErr bool
}{
{
name: "IPv4 /24 network - last IP (fromEnd=0)",
network: "192.168.1.0/24",
fromEnd: 0,
expected: "192.168.1.255",
},
{
name: "IPv4 /24 network - fromEnd=1",
network: "192.168.1.0/24",
fromEnd: 1,
expected: "192.168.1.254",
},
{
name: "IPv4 /24 network - fromEnd=5",
network: "192.168.1.0/24",
fromEnd: 5,
expected: "192.168.1.250",
},
{
name: "IPv4 /16 network - last IP",
network: "10.0.0.0/16",
fromEnd: 0,
expected: "10.0.255.255",
},
{
name: "IPv4 /16 network - fromEnd=256",
network: "10.0.0.0/16",
fromEnd: 256,
expected: "10.0.254.255",
},
{
name: "IPv4 /32 network - single host",
network: "192.168.1.100/32",
fromEnd: 0,
expected: "192.168.1.100",
},
{
name: "IPv6 /64 network - last IP",
network: "2001:db8::/64",
fromEnd: 0,
expected: "2001:db8::ffff:ffff:ffff:ffff",
},
{
name: "IPv6 /64 network - fromEnd=1",
network: "2001:db8::/64",
fromEnd: 1,
expected: "2001:db8::ffff:ffff:ffff:fffe",
},
{
name: "IPv6 /128 network - single host",
network: "2001:db8::1/128",
fromEnd: 0,
expected: "2001:db8::1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
network, err := netip.ParsePrefix(tt.network)
require.NoError(t, err, "Failed to parse network prefix")
result, err := GetLastIPFromNetwork(network, tt.fromEnd)
if tt.expectErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
expectedIP, err := netip.ParseAddr(tt.expected)
require.NoError(t, err, "Failed to parse expected IP")
assert.Equal(t, expectedIP, result, "IP mismatch for network %s with fromEnd=%d", tt.network, tt.fromEnd)
})
}
}

284
client/net/net_windows.go Normal file
View File

@@ -0,0 +1,284 @@
package net
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"syscall"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
// https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options
IpUnicastIf = 31
Ipv6UnicastIf = 31
// https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ipv6-socket-options
Ipv6V6only = 27
)
// GetBestInterfaceFunc is set at runtime to avoid import cycle
var GetBestInterfaceFunc func(dest netip.Addr, vpnIntf string) (*net.Interface, error)
// nativeToBigEndian converts a uint32 from native byte order to big-endian
func nativeToBigEndian(v uint32) uint32 {
return (v&0xff)<<24 | (v&0xff00)<<8 | (v&0xff0000)>>8 | (v&0xff000000)>>24
}
// parseDestinationAddress parses the destination address from various formats
func parseDestinationAddress(network, address string) (netip.Addr, error) {
if address == "" {
if strings.HasSuffix(network, "6") {
return netip.IPv6Unspecified(), nil
}
return netip.IPv4Unspecified(), nil
}
if addrPort, err := netip.ParseAddrPort(address); err == nil {
return addrPort.Addr(), nil
}
if dest, err := netip.ParseAddr(address); err == nil {
return dest, nil
}
host, _, err := net.SplitHostPort(address)
if err != nil {
// No port, treat whole string as host
host = address
}
if host == "" {
if strings.HasSuffix(network, "6") {
return netip.IPv6Unspecified(), nil
}
return netip.IPv4Unspecified(), nil
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil || len(ips) == 0 {
return netip.Addr{}, fmt.Errorf("resolve destination %s: %w", host, err)
}
dest, ok := netip.AddrFromSlice(ips[0].IP)
if !ok {
return netip.Addr{}, fmt.Errorf("convert IP %v to netip.Addr", ips[0].IP)
}
if ips[0].Zone != "" {
dest = dest.WithZone(ips[0].Zone)
}
return dest, nil
}
func getInterfaceFromZone(zone string) *net.Interface {
if zone == "" {
return nil
}
idx, err := strconv.Atoi(zone)
if err != nil {
log.Debugf("invalid zone format for Windows (expected numeric): %s", zone)
return nil
}
iface, err := net.InterfaceByIndex(idx)
if err != nil {
log.Debugf("failed to get interface by index %d from zone: %v", idx, err)
return nil
}
return iface
}
type interfaceSelection struct {
iface4 *net.Interface
iface6 *net.Interface
}
func selectInterfaceForZone(dest netip.Addr, zone string) *interfaceSelection {
iface := getInterfaceFromZone(zone)
if iface == nil {
return nil
}
if dest.Is6() {
return &interfaceSelection{iface6: iface}
}
return &interfaceSelection{iface4: iface}
}
func selectInterfaceForUnspecified() (*interfaceSelection, error) {
if GetBestInterfaceFunc == nil {
return nil, errors.New("GetBestInterfaceFunc not initialized")
}
var result interfaceSelection
vpnIfaceName := GetVPNInterfaceName()
if iface4, err := GetBestInterfaceFunc(netip.IPv4Unspecified(), vpnIfaceName); err == nil {
result.iface4 = iface4
} else {
log.Debugf("No IPv4 default route found: %v", err)
}
if iface6, err := GetBestInterfaceFunc(netip.IPv6Unspecified(), vpnIfaceName); err == nil {
result.iface6 = iface6
} else {
log.Debugf("No IPv6 default route found: %v", err)
}
if result.iface4 == nil && result.iface6 == nil {
return nil, errors.New("no default routes found")
}
return &result, nil
}
func selectInterface(dest netip.Addr) (*interfaceSelection, error) {
if zone := dest.Zone(); zone != "" {
if selection := selectInterfaceForZone(dest, zone); selection != nil {
return selection, nil
}
}
if dest.IsUnspecified() {
return selectInterfaceForUnspecified()
}
if GetBestInterfaceFunc == nil {
return nil, errors.New("GetBestInterfaceFunc not initialized")
}
iface, err := GetBestInterfaceFunc(dest, GetVPNInterfaceName())
if err != nil {
return nil, fmt.Errorf("find route for %s: %w", dest, err)
}
if dest.Is6() {
return &interfaceSelection{iface6: iface}, nil
}
return &interfaceSelection{iface4: iface}, nil
}
func setIPv4UnicastIF(fd uintptr, iface *net.Interface) error {
ifaceIndexBE := nativeToBigEndian(uint32(iface.Index))
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IpUnicastIf, int(ifaceIndexBE)); err != nil {
return fmt.Errorf("set IP_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
func setIPv6UnicastIF(fd uintptr, iface *net.Interface) error {
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6UnicastIf, iface.Index); err != nil {
return fmt.Errorf("set IPV6_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
func setUnicastIf(fd uintptr, network string, selection *interfaceSelection, address string) error {
// The Go runtime always passes specific network types to Control (udp4, udp6, tcp4, tcp6, etc.)
// Never generic ones (udp, tcp, ip)
switch {
case strings.HasSuffix(network, "4"):
// IPv4-only socket (udp4, tcp4, ip4)
return setUnicastIfIPv4(fd, network, selection, address)
case strings.HasSuffix(network, "6"):
// IPv6 socket (udp6, tcp6, ip6) - could be dual-stack or IPv6-only
return setUnicastIfIPv6(fd, network, selection, address)
}
// Shouldn't reach here based on Go's documented behavior
return fmt.Errorf("unexpected network type: %s", network)
}
func setUnicastIfIPv4(fd uintptr, network string, selection *interfaceSelection, address string) error {
if selection.iface4 == nil {
return nil
}
if err := setIPv4UnicastIF(fd, selection.iface4); err != nil {
return err
}
log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s", selection.iface4.Index, selection.iface4.Name, network, address)
return nil
}
func setUnicastIfIPv6(fd uintptr, network string, selection *interfaceSelection, address string) error {
isDualStack := checkDualStack(fd)
// For dual-stack sockets, also set the IPv4 option
if isDualStack && selection.iface4 != nil {
if err := setIPv4UnicastIF(fd, selection.iface4); err != nil {
return err
}
log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s (dual-stack)", selection.iface4.Index, selection.iface4.Name, network, address)
}
if selection.iface6 == nil {
return nil
}
if err := setIPv6UnicastIF(fd, selection.iface6); err != nil {
return err
}
log.Debugf("Set IPV6_UNICAST_IF=%d on %s for %s to %s", selection.iface6.Index, selection.iface6.Name, network, address)
return nil
}
func checkDualStack(fd uintptr) bool {
var v6Only int
v6OnlyLen := int32(unsafe.Sizeof(v6Only))
err := windows.Getsockopt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6V6only, (*byte)(unsafe.Pointer(&v6Only)), &v6OnlyLen)
return err == nil && v6Only == 0
}
// applyUnicastIFToSocket applies IpUnicastIf to a socket based on the destination address
func applyUnicastIFToSocket(network string, address string, c syscall.RawConn) error {
if !AdvancedRouting() {
return nil
}
dest, err := parseDestinationAddress(network, address)
if err != nil {
return err
}
dest = dest.Unmap()
if !dest.IsValid() {
return fmt.Errorf("invalid destination address for %s", address)
}
selection, err := selectInterface(dest)
if err != nil {
return err
}
var controlErr error
err = c.Control(func(fd uintptr) {
controlErr = setUnicastIf(fd, network, selection, address)
})
if err != nil {
return fmt.Errorf("control: %w", err)
}
return controlErr
}

View File

@@ -0,0 +1,47 @@
package net
import (
"fmt"
"sync"
"syscall"
"github.com/netbirdio/netbird/client/iface/netstack"
)
var (
androidProtectSocketLock sync.Mutex
androidProtectSocket func(fd int32) bool
)
func SetAndroidProtectSocketFn(fn func(fd int32) bool) {
androidProtectSocketLock.Lock()
androidProtectSocket = fn
androidProtectSocketLock.Unlock()
}
// ControlProtectSocket is a Control function that sets the fwmark on the socket
func ControlProtectSocket(_, _ string, c syscall.RawConn) error {
if netstack.IsEnabled() {
return nil
}
var aErr error
err := c.Control(func(fd uintptr) {
androidProtectSocketLock.Lock()
defer androidProtectSocketLock.Unlock()
if androidProtectSocket == nil {
aErr = fmt.Errorf("socket protection function not set")
return
}
if !androidProtectSocket(int32(fd)) {
aErr = fmt.Errorf("failed to protect socket via Android")
}
})
if err != nil {
return err
}
return aErr
}