[client] Use native windows sock opts to avoid routing loops (#4314)

- Move `util/grpc` and `util/net` to `client` so `internal` packages can be accessed
 - Add methods to return the next best interface after the NetBird interface.
- Use `IP_UNICAST_IF` sock opt to force the outgoing interface for the NetBird `net.Dialer` and `net.ListenerConfig` to avoid routing loops. The interface is picked by the new route lookup method.
- Some refactoring to avoid import cycles
- Old behavior is available through `NB_USE_LEGACY_ROUTING=true` env var
This commit is contained in:
Viktor Liu
2025-09-20 09:31:04 +02:00
committed by GitHub
parent 90577682e4
commit 55126f990c
77 changed files with 1180 additions and 606 deletions

View File

@@ -1,94 +0,0 @@
package grpc
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os/user"
"runtime"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/util/embeddedroots"
nbnet "github.com/netbirdio/netbird/util/net"
)
func WithCustomDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
})
}
// grpcDialBackoff is the backoff mechanism for the grpc calls
func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff()
b.MaxElapsedTime = 10 * time.Second
b.Clock = backoff.SystemClock
return backoff.WithContext(b, ctx)
}
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
certPool = embeddedroots.Get()
}
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
RootCAs: certPool,
}))
}
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(
connCtx,
addr,
transportOption,
WithCustomDialer(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,
Timeout: 10 * time.Second,
}),
)
if err != nil {
log.Printf("DialContext error: %v", err)
return nil, err
}
return conn, nil
}

View File

@@ -1,31 +0,0 @@
//go:build !ios
package net
import (
"net"
log "github.com/sirupsen/logrus"
)
// Conn wraps a net.Conn to override the Close method
type Conn struct {
net.Conn
ID ConnectionID
}
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
func (c *Conn) Close() error {
err := c.Conn.Close()
dialerCloseHooksMutex.RLock()
defer dialerCloseHooksMutex.RUnlock()
for _, hook := range dialerCloseHooks {
if err := hook(c.ID, &c.Conn); err != nil {
log.Errorf("Error executing dialer close hook: %v", err)
}
}
return err
}

View File

@@ -1,58 +0,0 @@
//go:build !ios
package net
import (
"fmt"
"net"
log "github.com/sirupsen/logrus"
)
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
if CustomRoutingDisabled() {
return net.DialUDP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
}
return udpConn, nil
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
if CustomRoutingDisabled() {
return net.DialTCP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
}
return tcpConn, nil
}

View File

@@ -1,13 +0,0 @@
package net
import (
"net"
)
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
return net.DialUDP(network, laddr, raddr)
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
return net.DialTCP(network, laddr, raddr)
}

View File

@@ -1,21 +0,0 @@
package net
import (
"net"
)
// Dialer extends the standard net.Dialer with the ability to execute hooks before
// and after connections. This can be used to bypass the VPN for connections using this dialer.
type Dialer struct {
*net.Dialer
}
// NewDialer returns a customized net.Dialer with overridden Control method
func NewDialer() *Dialer {
dialer := &Dialer{
Dialer: &net.Dialer{},
}
dialer.init()
return dialer
}

View File

@@ -1,107 +0,0 @@
//go:build !ios
package net
import (
"context"
"fmt"
"net"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
)
type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error
type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error
var (
dialerDialHooksMutex sync.RWMutex
dialerDialHooks []DialerDialHookFunc
dialerCloseHooksMutex sync.RWMutex
dialerCloseHooks []DialerCloseHookFunc
)
// AddDialerHook allows adding a new hook to be executed before dialing.
func AddDialerHook(hook DialerDialHookFunc) {
dialerDialHooksMutex.Lock()
defer dialerDialHooksMutex.Unlock()
dialerDialHooks = append(dialerDialHooks, hook)
}
// AddDialerCloseHook allows adding a new hook to be executed on connection close.
func AddDialerCloseHook(hook DialerCloseHookFunc) {
dialerCloseHooksMutex.Lock()
defer dialerCloseHooksMutex.Unlock()
dialerCloseHooks = append(dialerCloseHooks, hook)
}
// RemoveDialerHooks removes all dialer hooks.
func RemoveDialerHooks() {
dialerDialHooksMutex.Lock()
defer dialerDialHooksMutex.Unlock()
dialerDialHooks = nil
dialerCloseHooksMutex.Lock()
defer dialerCloseHooksMutex.Unlock()
dialerCloseHooks = nil
}
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
log.Debugf("Dialing %s %s", network, address)
if CustomRoutingDisabled() {
return d.Dialer.DialContext(ctx, network, address)
}
var resolver *net.Resolver
if d.Resolver != nil {
resolver = d.Resolver
}
connID := GenerateConnID()
if dialerDialHooks != nil {
if err := callDialerHooks(ctx, connID, address, resolver); err != nil {
log.Errorf("Failed to call dialer hooks: %v", err)
}
}
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
}
// Wrap the connection in Conn to handle Close with hooks
return &Conn{Conn: conn, ID: connID}, nil
}
// Dial wraps the net.Dialer's Dial method to use the custom connection
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
host, _, err := net.SplitHostPort(address)
if err != nil {
return fmt.Errorf("split host and port: %w", err)
}
ips, err := resolver.LookupIPAddr(ctx, host)
if err != nil {
return fmt.Errorf("failed to resolve address %s: %w", address, err)
}
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
var result *multierror.Error
dialerDialHooksMutex.RLock()
defer dialerDialHooksMutex.RUnlock()
for _, hook := range dialerDialHooks {
if err := hook(ctx, connID, ips); err != nil {
result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err))
}
}
return result.ErrorOrNil()
}

View File

@@ -1,5 +0,0 @@
package net
func (d *Dialer) init() {
d.Dialer.Control = ControlProtectSocket
}

View File

@@ -1,12 +0,0 @@
//go:build !android
package net
import "syscall"
// init configures the net.Dialer Control function to set the fwmark on the socket
func (d *Dialer) init() {
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
return setRawSocketMark(c)
}
}

View File

@@ -1,7 +0,0 @@
//go:build !linux
package net
func (d *Dialer) init() {
// implemented on Linux and Android only
}

View File

@@ -1,34 +0,0 @@
package net
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
)
const (
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
)
// CustomRoutingDisabled returns true if custom routing is disabled.
// This will fall back to the operation mode before the exit node functionality was implemented.
// In particular exclusion routes won't be set up and all dialers and listeners will use net.Dial and net.Listen, respectively.
func CustomRoutingDisabled() bool {
if netstack.IsEnabled() {
return true
}
var customRoutingDisabled bool
if val := os.Getenv(envDisableCustomRouting); val != "" {
var err error
customRoutingDisabled, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envDisableCustomRouting, err)
}
}
return customRoutingDisabled
}

View File

@@ -1,12 +0,0 @@
//go:build !linux || android
package net
func Init() {
// nothing to do on non-linux
}
func AdvancedRouting() bool {
// non-linux currently doesn't support advanced routing
return false
}

View File

@@ -1,131 +0,0 @@
//go:build linux && !android
package net
import (
"errors"
"os"
"strconv"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/iface/netstack"
)
const (
// these have the same effect, skip socket env supported for backward compatibility
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
envUseLegacyRouting = "NB_USE_LEGACY_ROUTING"
)
var advancedRoutingSupported bool
func Init() {
advancedRoutingSupported = checkAdvancedRoutingSupport()
}
func AdvancedRouting() bool {
return advancedRoutingSupported
}
func checkAdvancedRoutingSupport() bool {
var err error
var legacyRouting bool
if val := os.Getenv(envUseLegacyRouting); val != "" {
legacyRouting, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
}
}
var skipSocketMark bool
if val := os.Getenv(envSkipSocketMark); val != "" {
skipSocketMark, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envSkipSocketMark, err)
}
}
// requested to disable advanced routing
if legacyRouting || skipSocketMark ||
// envCustomRoutingDisabled disables the custom dialers.
// There is no point in using advanced routing without those, as they set up fwmarks on the sockets.
CustomRoutingDisabled() ||
// netstack mode doesn't need routing at all
netstack.IsEnabled() {
log.Info("advanced routing has been requested to be disabled")
return false
}
if !CheckFwmarkSupport() || !CheckRuleOperationsSupport() {
log.Warn("system doesn't support required routing features, falling back to legacy routing")
return false
}
log.Info("system supports advanced routing")
return true
}
func CheckFwmarkSupport() bool {
// temporarily enable advanced routing to check fwmarks are supported
old := advancedRoutingSupported
advancedRoutingSupported = true
defer func() {
advancedRoutingSupported = old
}()
dialer := NewDialer()
dialer.Timeout = 100 * time.Millisecond
conn, err := dialer.Dial("udp", "127.0.0.1:9")
if err != nil {
log.Warnf("failed to dial with fwmark: %v", err)
return false
}
defer func() {
if err := conn.Close(); err != nil {
log.Warnf("failed to close connection: %v", err)
}
}()
if err := conn.SetWriteDeadline(time.Now().Add(time.Millisecond * 100)); err != nil {
log.Warnf("failed to set write deadline: %v", err)
return false
}
if _, err := conn.Write([]byte("")); err != nil {
log.Warnf("failed to write to fwmark connection: %v", err)
return false
}
return true
}
func CheckRuleOperationsSupport() bool {
rule := netlink.NewRule()
// low precedence, semi-random
rule.Priority = 32321
rule.Table = syscall.RT_TABLE_MAIN
rule.Family = netlink.FAMILY_V4
if err := netlink.RuleAdd(rule); err != nil {
if errors.Is(err, syscall.EOPNOTSUPP) {
log.Warn("IP rule operations are not supported")
return false
}
log.Warnf("failed to test rule support: %v", err)
return false
}
if err := netlink.RuleDel(rule); err != nil {
log.Warnf("failed to delete test rule: %v", err)
}
return true
}

View File

@@ -1,37 +0,0 @@
//go:build !ios
package net
import (
"context"
"fmt"
"net"
"sync"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
)
// ListenUDP listens on the network address and returns a transport.UDPConn
// which includes support for write and close hooks.
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
if CustomRoutingDisabled() {
return net.ListenUDP(network, laddr)
}
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listen UDP: %w", err)
}
packetConn := conn.(*PacketConn)
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
if !ok {
if err := packetConn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
}
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
}

View File

@@ -1,11 +0,0 @@
//go:build ios
package net
import (
"net"
)
func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
return net.ListenUDP(network, laddr)
}

View File

@@ -1,21 +0,0 @@
package net
import (
"net"
)
// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before
// responding via the socket and after closing. This can be used to bypass the VPN for listeners.
type ListenerConfig struct {
*net.ListenConfig
}
// NewListener creates a new ListenerConfig instance.
func NewListener() *ListenerConfig {
listener := &ListenerConfig{
ListenConfig: &net.ListenConfig{},
}
listener.init()
return listener
}

View File

@@ -1,6 +0,0 @@
package net
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
func (l *ListenerConfig) init() {
l.ListenConfig.Control = ControlProtectSocket
}

View File

@@ -1,14 +0,0 @@
//go:build !android
package net
import (
"syscall"
)
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
func (l *ListenerConfig) init() {
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
return setRawSocketMark(c)
}
}

View File

@@ -1,7 +0,0 @@
//go:build !linux
package net
func (l *ListenerConfig) init() {
// implemented on Linux and Android only
}

View File

@@ -1,205 +0,0 @@
//go:build !ios
package net
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
)
// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn.
type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error
// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn.
type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error
// ListenerAddressRemoveHookFunc defines the function signature for hooks called when addresses are removed.
type ListenerAddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error
var (
listenerWriteHooksMutex sync.RWMutex
listenerWriteHooks []ListenerWriteHookFunc
listenerCloseHooksMutex sync.RWMutex
listenerCloseHooks []ListenerCloseHookFunc
listenerAddressRemoveHooksMutex sync.RWMutex
listenerAddressRemoveHooks []ListenerAddressRemoveHookFunc
)
// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent.
func AddListenerWriteHook(hook ListenerWriteHookFunc) {
listenerWriteHooksMutex.Lock()
defer listenerWriteHooksMutex.Unlock()
listenerWriteHooks = append(listenerWriteHooks, hook)
}
// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection.
func AddListenerCloseHook(hook ListenerCloseHookFunc) {
listenerCloseHooksMutex.Lock()
defer listenerCloseHooksMutex.Unlock()
listenerCloseHooks = append(listenerCloseHooks, hook)
}
// AddListenerAddressRemoveHook allows adding a new hook to be executed when an address is removed.
func AddListenerAddressRemoveHook(hook ListenerAddressRemoveHookFunc) {
listenerAddressRemoveHooksMutex.Lock()
defer listenerAddressRemoveHooksMutex.Unlock()
listenerAddressRemoveHooks = append(listenerAddressRemoveHooks, hook)
}
// RemoveListenerHooks removes all listener hooks.
func RemoveListenerHooks() {
listenerWriteHooksMutex.Lock()
defer listenerWriteHooksMutex.Unlock()
listenerWriteHooks = nil
listenerCloseHooksMutex.Lock()
defer listenerCloseHooksMutex.Unlock()
listenerCloseHooks = nil
listenerAddressRemoveHooksMutex.Lock()
defer listenerAddressRemoveHooksMutex.Unlock()
listenerAddressRemoveHooks = nil
}
// ListenPacket listens on the network address and returns a PacketConn
// which includes support for write hooks.
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
if CustomRoutingDisabled() {
return l.ListenConfig.ListenPacket(ctx, network, address)
}
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("listen packet: %w", err)
}
connID := GenerateConnID()
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
}
// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality.
type PacketConn struct {
net.PacketConn
ID ConnectionID
seenAddrs *sync.Map
}
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
callWriteHooks(c.ID, c.seenAddrs, b, addr)
return c.PacketConn.WriteTo(b, addr)
}
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
func (c *PacketConn) Close() error {
c.seenAddrs = &sync.Map{}
return closeConn(c.ID, c.PacketConn)
}
// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality.
type UDPConn struct {
*net.UDPConn
ID ConnectionID
seenAddrs *sync.Map
}
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
callWriteHooks(c.ID, c.seenAddrs, b, addr)
return c.UDPConn.WriteTo(b, addr)
}
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
func (c *UDPConn) Close() error {
c.seenAddrs = &sync.Map{}
return closeConn(c.ID, c.UDPConn)
}
// RemoveAddress removes an address from the seen cache and triggers removal hooks.
func (c *PacketConn) RemoveAddress(addr string) {
if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists {
return
}
ipStr, _, err := net.SplitHostPort(addr)
if err != nil {
log.Errorf("Error splitting IP address and port: %v", err)
return
}
ipAddr, err := netip.ParseAddr(ipStr)
if err != nil {
log.Errorf("Error parsing IP address %s: %v", ipStr, err)
return
}
prefix := netip.PrefixFrom(ipAddr, ipAddr.BitLen())
listenerAddressRemoveHooksMutex.RLock()
defer listenerAddressRemoveHooksMutex.RUnlock()
for _, hook := range listenerAddressRemoveHooks {
if err := hook(c.ID, prefix); err != nil {
log.Errorf("Error executing listener address remove hook: %v", err)
}
}
}
// WrapPacketConn wraps an existing net.PacketConn with nbnet functionality
func WrapPacketConn(conn net.PacketConn) *PacketConn {
return &PacketConn{
PacketConn: conn,
ID: GenerateConnID(),
seenAddrs: &sync.Map{},
}
}
func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) {
// Lookup the address in the seenAddrs map to avoid calling the hooks for every write
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded {
ipStr, _, splitErr := net.SplitHostPort(addr.String())
if splitErr != nil {
log.Errorf("Error splitting IP address and port: %v", splitErr)
return
}
ip, err := net.ResolveIPAddr("ip", ipStr)
if err != nil {
log.Errorf("Error resolving IP address: %v", err)
return
}
log.Debugf("Listener resolved IP for %s: %s", addr, ip)
func() {
listenerWriteHooksMutex.RLock()
defer listenerWriteHooksMutex.RUnlock()
for _, hook := range listenerWriteHooks {
if err := hook(id, ip, b); err != nil {
log.Errorf("Error executing listener write hook: %v", err)
}
}
}()
}
}
func closeConn(id ConnectionID, conn net.PacketConn) error {
err := conn.Close()
listenerCloseHooksMutex.RLock()
defer listenerCloseHooksMutex.RUnlock()
for _, hook := range listenerCloseHooks {
if err := hook(id, conn); err != nil {
log.Errorf("Error executing listener close hook: %v", err)
}
}
return err
}

View File

@@ -1,10 +0,0 @@
package net
import (
"net"
)
// WrapPacketConn on iOS just returns the original connection since iOS handles its own networking
func WrapPacketConn(conn *net.UDPConn) *net.UDPConn {
return conn
}

View File

@@ -1,83 +0,0 @@
package net
import (
"fmt"
"math/big"
"net"
"net/netip"
"github.com/google/uuid"
)
const (
// ControlPlaneMark is the fwmark value used to mark packets that should not be routed through the NetBird interface to
// avoid routing loops.
// This includes all control plane traffic (mgmt, signal, flows), relay, ICE/stun/turn and everything that is emitted by the wireguard socket.
// It doesn't collide with the other marks, as the others are used for data plane traffic only.
ControlPlaneMark = 0x1BD00
// Data plane marks (0x1BD10 - 0x1BDFF)
// DataPlaneMarkLower is the lowest value for the data plane range
DataPlaneMarkLower = 0x1BD10
// DataPlaneMarkUpper is the highest value for the data plane range
DataPlaneMarkUpper = 0x1BDFF
// DataPlaneMarkIn is the mark for inbound data plane traffic.
DataPlaneMarkIn = 0x1BD10
// DataPlaneMarkOut is the mark for outbound data plane traffic.
DataPlaneMarkOut = 0x1BD11
// PreroutingFwmarkRedirected is applied to packets that are were redirected (input -> forward, e.g. by Docker or Podman) for special handling.
PreroutingFwmarkRedirected = 0x1BD20
// PreroutingFwmarkMasquerade is applied to packets that arrive from the NetBird interface and should be masqueraded.
PreroutingFwmarkMasquerade = 0x1BD21
// PreroutingFwmarkMasqueradeReturn is applied to packets that will leave through the NetBird interface and should be masqueraded.
PreroutingFwmarkMasqueradeReturn = 0x1BD22
)
// IsDataPlaneMark determines if a fwmark is in the data plane range (0x1BD10-0x1BDFF)
func IsDataPlaneMark(fwmark uint32) bool {
return fwmark >= DataPlaneMarkLower && fwmark <= DataPlaneMarkUpper
}
// ConnectionID provides a globally unique identifier for network connections.
// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
type ConnectionID string
type AddHookFunc func(connID ConnectionID, IP net.IP) error
type RemoveHookFunc func(connID ConnectionID) error
// GenerateConnID generates a unique identifier for each connection.
func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString())
}
func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) {
var endIP net.IP
addr := network.Addr().AsSlice()
mask := net.CIDRMask(network.Bits(), len(addr)*8)
for i := 0; i < len(addr); i++ {
endIP = append(endIP, addr[i]|^mask[i])
}
// convert to big.Int
endInt := big.NewInt(0)
endInt.SetBytes(endIP)
// subtract fromEnd from the last ip
fromEndBig := big.NewInt(int64(fromEnd))
resultInt := big.NewInt(0)
resultInt.Sub(endInt, fromEndBig)
ip, ok := netip.AddrFromSlice(resultInt.Bytes())
if !ok {
return netip.Addr{}, fmt.Errorf("invalid IP address from network %s", network)
}
return ip.Unmap(), nil
}

View File

@@ -1,55 +0,0 @@
//go:build !android
package net
import (
"fmt"
"syscall"
)
// SetSocketMark sets the SO_MARK option on the given socket connection
func SetSocketMark(conn syscall.Conn) error {
if !AdvancedRouting() {
return nil
}
sysconn, err := conn.SyscallConn()
if err != nil {
return fmt.Errorf("get raw conn: %w", err)
}
return setRawSocketMark(sysconn)
}
// SetSocketOpt sets the SO_MARK option on the given file descriptor
func SetSocketOpt(fd int) error {
if !AdvancedRouting() {
return nil
}
return setSocketOptInt(fd)
}
func setRawSocketMark(conn syscall.RawConn) error {
var setErr error
err := conn.Control(func(fd uintptr) {
if !AdvancedRouting() {
return
}
setErr = setSocketOptInt(int(fd))
})
if err != nil {
return fmt.Errorf("control: %w", err)
}
if setErr != nil {
return fmt.Errorf("set SO_MARK: %w", setErr)
}
return nil
}
func setSocketOptInt(fd int) error {
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, ControlPlaneMark)
}

View File

@@ -1,94 +0,0 @@
package net
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
name string
network string
fromEnd int
expected string
expectErr bool
}{
{
name: "IPv4 /24 network - last IP (fromEnd=0)",
network: "192.168.1.0/24",
fromEnd: 0,
expected: "192.168.1.255",
},
{
name: "IPv4 /24 network - fromEnd=1",
network: "192.168.1.0/24",
fromEnd: 1,
expected: "192.168.1.254",
},
{
name: "IPv4 /24 network - fromEnd=5",
network: "192.168.1.0/24",
fromEnd: 5,
expected: "192.168.1.250",
},
{
name: "IPv4 /16 network - last IP",
network: "10.0.0.0/16",
fromEnd: 0,
expected: "10.0.255.255",
},
{
name: "IPv4 /16 network - fromEnd=256",
network: "10.0.0.0/16",
fromEnd: 256,
expected: "10.0.254.255",
},
{
name: "IPv4 /32 network - single host",
network: "192.168.1.100/32",
fromEnd: 0,
expected: "192.168.1.100",
},
{
name: "IPv6 /64 network - last IP",
network: "2001:db8::/64",
fromEnd: 0,
expected: "2001:db8::ffff:ffff:ffff:ffff",
},
{
name: "IPv6 /64 network - fromEnd=1",
network: "2001:db8::/64",
fromEnd: 1,
expected: "2001:db8::ffff:ffff:ffff:fffe",
},
{
name: "IPv6 /128 network - single host",
network: "2001:db8::1/128",
fromEnd: 0,
expected: "2001:db8::1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
network, err := netip.ParsePrefix(tt.network)
require.NoError(t, err, "Failed to parse network prefix")
result, err := GetLastIPFromNetwork(network, tt.fromEnd)
if tt.expectErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
expectedIP, err := netip.ParseAddr(tt.expected)
require.NoError(t, err, "Failed to parse expected IP")
assert.Equal(t, expectedIP, result, "IP mismatch for network %s with fromEnd=%d", tt.network, tt.fromEnd)
})
}
}

View File

@@ -1,47 +0,0 @@
package net
import (
"fmt"
"sync"
"syscall"
"github.com/netbirdio/netbird/client/iface/netstack"
)
var (
androidProtectSocketLock sync.Mutex
androidProtectSocket func(fd int32) bool
)
func SetAndroidProtectSocketFn(fn func(fd int32) bool) {
androidProtectSocketLock.Lock()
androidProtectSocket = fn
androidProtectSocketLock.Unlock()
}
// ControlProtectSocket is a Control function that sets the fwmark on the socket
func ControlProtectSocket(_, _ string, c syscall.RawConn) error {
if netstack.IsEnabled() {
return nil
}
var aErr error
err := c.Control(func(fd uintptr) {
androidProtectSocketLock.Lock()
defer androidProtectSocketLock.Unlock()
if androidProtectSocket == nil {
aErr = fmt.Errorf("socket protection function not set")
return
}
if !androidProtectSocket(int32(fd)) {
aErr = fmt.Errorf("failed to protect socket via Android")
}
})
if err != nil {
return err
}
return aErr
}