mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[client] Use native windows sock opts to avoid routing loops (#4314)
- Move `util/grpc` and `util/net` to `client` so `internal` packages can be accessed - Add methods to return the next best interface after the NetBird interface. - Use `IP_UNICAST_IF` sock opt to force the outgoing interface for the NetBird `net.Dialer` and `net.ListenerConfig` to avoid routing loops. The interface is picked by the new route lookup method. - Some refactoring to avoid import cycles - Old behavior is available through `NB_USE_LEGACY_ROUTING=true` env var
This commit is contained in:
@@ -1,94 +0,0 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func WithCustomDialer() grpc.DialOption {
|
||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
if runtime.GOOS == "linux" {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
||||
if currentUser.Uid != "0" {
|
||||
log.Debug("Not running as root, using standard dialer")
|
||||
dialer := &net.Dialer{}
|
||||
return dialer.DialContext(ctx, "tcp", addr)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to dial: %s", err)
|
||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
})
|
||||
}
|
||||
|
||||
// grpcDialBackoff is the backoff mechanism for the grpc calls
|
||||
func Backoff(ctx context.Context) backoff.BackOff {
|
||||
b := backoff.NewExponentialBackOff()
|
||||
b.MaxElapsedTime = 10 * time.Second
|
||||
b.Clock = backoff.SystemClock
|
||||
return backoff.WithContext(b, ctx)
|
||||
}
|
||||
|
||||
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
if tlsEnabled {
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||
certPool = embeddedroots.Get()
|
||||
}
|
||||
|
||||
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||
RootCAs: certPool,
|
||||
}))
|
||||
}
|
||||
|
||||
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
connCtx,
|
||||
addr,
|
||||
transportOption,
|
||||
WithCustomDialer(),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 30 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("DialContext error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Conn wraps a net.Conn to override the Close method
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
ID ConnectionID
|
||||
}
|
||||
|
||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
||||
func (c *Conn) Close() error {
|
||||
err := c.Conn.Close()
|
||||
|
||||
dialerCloseHooksMutex.RLock()
|
||||
defer dialerCloseHooksMutex.RUnlock()
|
||||
|
||||
for _, hook := range dialerCloseHooks {
|
||||
if err := hook(c.ID, &c.Conn); err != nil {
|
||||
log.Errorf("Error executing dialer close hook: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialUDP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.Dial(network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
|
||||
}
|
||||
|
||||
return udpConn, nil
|
||||
}
|
||||
|
||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialTCP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.Dial(network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
|
||||
}
|
||||
|
||||
return tcpConn, nil
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
return net.DialUDP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||
return net.DialTCP(network, laddr, raddr)
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// Dialer extends the standard net.Dialer with the ability to execute hooks before
|
||||
// and after connections. This can be used to bypass the VPN for connections using this dialer.
|
||||
type Dialer struct {
|
||||
*net.Dialer
|
||||
}
|
||||
|
||||
// NewDialer returns a customized net.Dialer with overridden Control method
|
||||
func NewDialer() *Dialer {
|
||||
dialer := &Dialer{
|
||||
Dialer: &net.Dialer{},
|
||||
}
|
||||
dialer.init()
|
||||
|
||||
return dialer
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error
|
||||
type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error
|
||||
|
||||
var (
|
||||
dialerDialHooksMutex sync.RWMutex
|
||||
dialerDialHooks []DialerDialHookFunc
|
||||
dialerCloseHooksMutex sync.RWMutex
|
||||
dialerCloseHooks []DialerCloseHookFunc
|
||||
)
|
||||
|
||||
// AddDialerHook allows adding a new hook to be executed before dialing.
|
||||
func AddDialerHook(hook DialerDialHookFunc) {
|
||||
dialerDialHooksMutex.Lock()
|
||||
defer dialerDialHooksMutex.Unlock()
|
||||
dialerDialHooks = append(dialerDialHooks, hook)
|
||||
}
|
||||
|
||||
// AddDialerCloseHook allows adding a new hook to be executed on connection close.
|
||||
func AddDialerCloseHook(hook DialerCloseHookFunc) {
|
||||
dialerCloseHooksMutex.Lock()
|
||||
defer dialerCloseHooksMutex.Unlock()
|
||||
dialerCloseHooks = append(dialerCloseHooks, hook)
|
||||
}
|
||||
|
||||
// RemoveDialerHooks removes all dialer hooks.
|
||||
func RemoveDialerHooks() {
|
||||
dialerDialHooksMutex.Lock()
|
||||
defer dialerDialHooksMutex.Unlock()
|
||||
dialerDialHooks = nil
|
||||
|
||||
dialerCloseHooksMutex.Lock()
|
||||
defer dialerCloseHooksMutex.Unlock()
|
||||
dialerCloseHooks = nil
|
||||
}
|
||||
|
||||
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
|
||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
log.Debugf("Dialing %s %s", network, address)
|
||||
|
||||
if CustomRoutingDisabled() {
|
||||
return d.Dialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
var resolver *net.Resolver
|
||||
if d.Resolver != nil {
|
||||
resolver = d.Resolver
|
||||
}
|
||||
|
||||
connID := GenerateConnID()
|
||||
if dialerDialHooks != nil {
|
||||
if err := callDialerHooks(ctx, connID, address, resolver); err != nil {
|
||||
log.Errorf("Failed to call dialer hooks: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
|
||||
}
|
||||
|
||||
// Wrap the connection in Conn to handle Close with hooks
|
||||
return &Conn{Conn: conn, ID: connID}, nil
|
||||
}
|
||||
|
||||
// Dial wraps the net.Dialer's Dial method to use the custom connection
|
||||
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
||||
return d.DialContext(context.Background(), network, address)
|
||||
}
|
||||
|
||||
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("split host and port: %w", err)
|
||||
}
|
||||
ips, err := resolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve address %s: %w", address, err)
|
||||
}
|
||||
|
||||
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
|
||||
|
||||
var result *multierror.Error
|
||||
|
||||
dialerDialHooksMutex.RLock()
|
||||
defer dialerDialHooksMutex.RUnlock()
|
||||
for _, hook := range dialerDialHooks {
|
||||
if err := hook(ctx, connID, ips); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
package net
|
||||
|
||||
func (d *Dialer) init() {
|
||||
d.Dialer.Control = ControlProtectSocket
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
//go:build !android
|
||||
|
||||
package net
|
||||
|
||||
import "syscall"
|
||||
|
||||
// init configures the net.Dialer Control function to set the fwmark on the socket
|
||||
func (d *Dialer) init() {
|
||||
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
|
||||
return setRawSocketMark(c)
|
||||
}
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package net
|
||||
|
||||
func (d *Dialer) init() {
|
||||
// implemented on Linux and Android only
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
)
|
||||
|
||||
const (
|
||||
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
||||
)
|
||||
|
||||
// CustomRoutingDisabled returns true if custom routing is disabled.
|
||||
// This will fall back to the operation mode before the exit node functionality was implemented.
|
||||
// In particular exclusion routes won't be set up and all dialers and listeners will use net.Dial and net.Listen, respectively.
|
||||
func CustomRoutingDisabled() bool {
|
||||
if netstack.IsEnabled() {
|
||||
return true
|
||||
}
|
||||
|
||||
var customRoutingDisabled bool
|
||||
if val := os.Getenv(envDisableCustomRouting); val != "" {
|
||||
var err error
|
||||
customRoutingDisabled, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", envDisableCustomRouting, err)
|
||||
}
|
||||
}
|
||||
|
||||
return customRoutingDisabled
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package net
|
||||
|
||||
func Init() {
|
||||
// nothing to do on non-linux
|
||||
}
|
||||
|
||||
func AdvancedRouting() bool {
|
||||
// non-linux currently doesn't support advanced routing
|
||||
return false
|
||||
}
|
||||
@@ -1,131 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
)
|
||||
|
||||
const (
|
||||
// these have the same effect, skip socket env supported for backward compatibility
|
||||
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
||||
envUseLegacyRouting = "NB_USE_LEGACY_ROUTING"
|
||||
)
|
||||
|
||||
var advancedRoutingSupported bool
|
||||
|
||||
func Init() {
|
||||
advancedRoutingSupported = checkAdvancedRoutingSupport()
|
||||
}
|
||||
|
||||
func AdvancedRouting() bool {
|
||||
return advancedRoutingSupported
|
||||
}
|
||||
|
||||
func checkAdvancedRoutingSupport() bool {
|
||||
var err error
|
||||
|
||||
var legacyRouting bool
|
||||
if val := os.Getenv(envUseLegacyRouting); val != "" {
|
||||
legacyRouting, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
|
||||
}
|
||||
}
|
||||
|
||||
var skipSocketMark bool
|
||||
if val := os.Getenv(envSkipSocketMark); val != "" {
|
||||
skipSocketMark, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", envSkipSocketMark, err)
|
||||
}
|
||||
}
|
||||
|
||||
// requested to disable advanced routing
|
||||
if legacyRouting || skipSocketMark ||
|
||||
// envCustomRoutingDisabled disables the custom dialers.
|
||||
// There is no point in using advanced routing without those, as they set up fwmarks on the sockets.
|
||||
CustomRoutingDisabled() ||
|
||||
// netstack mode doesn't need routing at all
|
||||
netstack.IsEnabled() {
|
||||
|
||||
log.Info("advanced routing has been requested to be disabled")
|
||||
return false
|
||||
}
|
||||
|
||||
if !CheckFwmarkSupport() || !CheckRuleOperationsSupport() {
|
||||
log.Warn("system doesn't support required routing features, falling back to legacy routing")
|
||||
return false
|
||||
}
|
||||
|
||||
log.Info("system supports advanced routing")
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func CheckFwmarkSupport() bool {
|
||||
// temporarily enable advanced routing to check fwmarks are supported
|
||||
old := advancedRoutingSupported
|
||||
advancedRoutingSupported = true
|
||||
defer func() {
|
||||
advancedRoutingSupported = old
|
||||
}()
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.Timeout = 100 * time.Millisecond
|
||||
|
||||
conn, err := dialer.Dial("udp", "127.0.0.1:9")
|
||||
if err != nil {
|
||||
log.Warnf("failed to dial with fwmark: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Warnf("failed to close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := conn.SetWriteDeadline(time.Now().Add(time.Millisecond * 100)); err != nil {
|
||||
log.Warnf("failed to set write deadline: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte("")); err != nil {
|
||||
log.Warnf("failed to write to fwmark connection: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func CheckRuleOperationsSupport() bool {
|
||||
rule := netlink.NewRule()
|
||||
// low precedence, semi-random
|
||||
rule.Priority = 32321
|
||||
rule.Table = syscall.RT_TABLE_MAIN
|
||||
rule.Family = netlink.FAMILY_V4
|
||||
|
||||
if err := netlink.RuleAdd(rule); err != nil {
|
||||
if errors.Is(err, syscall.EOPNOTSUPP) {
|
||||
log.Warn("IP rule operations are not supported")
|
||||
return false
|
||||
}
|
||||
log.Warnf("failed to test rule support: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if err := netlink.RuleDel(rule); err != nil {
|
||||
log.Warnf("failed to delete test rule: %v", err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ListenUDP listens on the network address and returns a transport.UDPConn
|
||||
// which includes support for write and close hooks.
|
||||
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.ListenUDP(network, laddr)
|
||||
}
|
||||
|
||||
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen UDP: %w", err)
|
||||
}
|
||||
|
||||
packetConn := conn.(*PacketConn)
|
||||
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := packetConn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
|
||||
}
|
||||
|
||||
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
//go:build ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
return net.ListenUDP(network, laddr)
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before
|
||||
// responding via the socket and after closing. This can be used to bypass the VPN for listeners.
|
||||
type ListenerConfig struct {
|
||||
*net.ListenConfig
|
||||
}
|
||||
|
||||
// NewListener creates a new ListenerConfig instance.
|
||||
func NewListener() *ListenerConfig {
|
||||
listener := &ListenerConfig{
|
||||
ListenConfig: &net.ListenConfig{},
|
||||
}
|
||||
listener.init()
|
||||
|
||||
return listener
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
package net
|
||||
|
||||
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
||||
func (l *ListenerConfig) init() {
|
||||
l.ListenConfig.Control = ControlProtectSocket
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build !android
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
||||
func (l *ListenerConfig) init() {
|
||||
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
|
||||
return setRawSocketMark(c)
|
||||
}
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package net
|
||||
|
||||
func (l *ListenerConfig) init() {
|
||||
// implemented on Linux and Android only
|
||||
}
|
||||
@@ -1,205 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn.
|
||||
type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error
|
||||
|
||||
// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn.
|
||||
type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error
|
||||
|
||||
// ListenerAddressRemoveHookFunc defines the function signature for hooks called when addresses are removed.
|
||||
type ListenerAddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error
|
||||
|
||||
var (
|
||||
listenerWriteHooksMutex sync.RWMutex
|
||||
listenerWriteHooks []ListenerWriteHookFunc
|
||||
listenerCloseHooksMutex sync.RWMutex
|
||||
listenerCloseHooks []ListenerCloseHookFunc
|
||||
listenerAddressRemoveHooksMutex sync.RWMutex
|
||||
listenerAddressRemoveHooks []ListenerAddressRemoveHookFunc
|
||||
)
|
||||
|
||||
// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent.
|
||||
func AddListenerWriteHook(hook ListenerWriteHookFunc) {
|
||||
listenerWriteHooksMutex.Lock()
|
||||
defer listenerWriteHooksMutex.Unlock()
|
||||
listenerWriteHooks = append(listenerWriteHooks, hook)
|
||||
}
|
||||
|
||||
// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection.
|
||||
func AddListenerCloseHook(hook ListenerCloseHookFunc) {
|
||||
listenerCloseHooksMutex.Lock()
|
||||
defer listenerCloseHooksMutex.Unlock()
|
||||
listenerCloseHooks = append(listenerCloseHooks, hook)
|
||||
}
|
||||
|
||||
// AddListenerAddressRemoveHook allows adding a new hook to be executed when an address is removed.
|
||||
func AddListenerAddressRemoveHook(hook ListenerAddressRemoveHookFunc) {
|
||||
listenerAddressRemoveHooksMutex.Lock()
|
||||
defer listenerAddressRemoveHooksMutex.Unlock()
|
||||
listenerAddressRemoveHooks = append(listenerAddressRemoveHooks, hook)
|
||||
}
|
||||
|
||||
// RemoveListenerHooks removes all listener hooks.
|
||||
func RemoveListenerHooks() {
|
||||
listenerWriteHooksMutex.Lock()
|
||||
defer listenerWriteHooksMutex.Unlock()
|
||||
listenerWriteHooks = nil
|
||||
|
||||
listenerCloseHooksMutex.Lock()
|
||||
defer listenerCloseHooksMutex.Unlock()
|
||||
listenerCloseHooks = nil
|
||||
|
||||
listenerAddressRemoveHooksMutex.Lock()
|
||||
defer listenerAddressRemoveHooksMutex.Unlock()
|
||||
listenerAddressRemoveHooks = nil
|
||||
}
|
||||
|
||||
// ListenPacket listens on the network address and returns a PacketConn
|
||||
// which includes support for write hooks.
|
||||
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return l.ListenConfig.ListenPacket(ctx, network, address)
|
||||
}
|
||||
|
||||
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen packet: %w", err)
|
||||
}
|
||||
connID := GenerateConnID()
|
||||
|
||||
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
|
||||
}
|
||||
|
||||
// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality.
|
||||
type PacketConn struct {
|
||||
net.PacketConn
|
||||
ID ConnectionID
|
||||
seenAddrs *sync.Map
|
||||
}
|
||||
|
||||
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
|
||||
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
callWriteHooks(c.ID, c.seenAddrs, b, addr)
|
||||
return c.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
|
||||
func (c *PacketConn) Close() error {
|
||||
c.seenAddrs = &sync.Map{}
|
||||
return closeConn(c.ID, c.PacketConn)
|
||||
}
|
||||
|
||||
// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality.
|
||||
type UDPConn struct {
|
||||
*net.UDPConn
|
||||
ID ConnectionID
|
||||
seenAddrs *sync.Map
|
||||
}
|
||||
|
||||
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
|
||||
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
callWriteHooks(c.ID, c.seenAddrs, b, addr)
|
||||
return c.UDPConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
|
||||
func (c *UDPConn) Close() error {
|
||||
c.seenAddrs = &sync.Map{}
|
||||
return closeConn(c.ID, c.UDPConn)
|
||||
}
|
||||
|
||||
// RemoveAddress removes an address from the seen cache and triggers removal hooks.
|
||||
func (c *PacketConn) RemoveAddress(addr string) {
|
||||
if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists {
|
||||
return
|
||||
}
|
||||
|
||||
ipStr, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
log.Errorf("Error splitting IP address and port: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ipAddr, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
log.Errorf("Error parsing IP address %s: %v", ipStr, err)
|
||||
return
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(ipAddr, ipAddr.BitLen())
|
||||
|
||||
listenerAddressRemoveHooksMutex.RLock()
|
||||
defer listenerAddressRemoveHooksMutex.RUnlock()
|
||||
|
||||
for _, hook := range listenerAddressRemoveHooks {
|
||||
if err := hook(c.ID, prefix); err != nil {
|
||||
log.Errorf("Error executing listener address remove hook: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// WrapPacketConn wraps an existing net.PacketConn with nbnet functionality
|
||||
func WrapPacketConn(conn net.PacketConn) *PacketConn {
|
||||
return &PacketConn{
|
||||
PacketConn: conn,
|
||||
ID: GenerateConnID(),
|
||||
seenAddrs: &sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) {
|
||||
// Lookup the address in the seenAddrs map to avoid calling the hooks for every write
|
||||
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded {
|
||||
ipStr, _, splitErr := net.SplitHostPort(addr.String())
|
||||
if splitErr != nil {
|
||||
log.Errorf("Error splitting IP address and port: %v", splitErr)
|
||||
return
|
||||
}
|
||||
|
||||
ip, err := net.ResolveIPAddr("ip", ipStr)
|
||||
if err != nil {
|
||||
log.Errorf("Error resolving IP address: %v", err)
|
||||
return
|
||||
}
|
||||
log.Debugf("Listener resolved IP for %s: %s", addr, ip)
|
||||
|
||||
func() {
|
||||
listenerWriteHooksMutex.RLock()
|
||||
defer listenerWriteHooksMutex.RUnlock()
|
||||
|
||||
for _, hook := range listenerWriteHooks {
|
||||
if err := hook(id, ip, b); err != nil {
|
||||
log.Errorf("Error executing listener write hook: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func closeConn(id ConnectionID, conn net.PacketConn) error {
|
||||
err := conn.Close()
|
||||
|
||||
listenerCloseHooksMutex.RLock()
|
||||
defer listenerCloseHooksMutex.RUnlock()
|
||||
|
||||
for _, hook := range listenerCloseHooks {
|
||||
if err := hook(id, conn); err != nil {
|
||||
log.Errorf("Error executing listener close hook: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// WrapPacketConn on iOS just returns the original connection since iOS handles its own networking
|
||||
func WrapPacketConn(conn *net.UDPConn) *net.UDPConn {
|
||||
return conn
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
// ControlPlaneMark is the fwmark value used to mark packets that should not be routed through the NetBird interface to
|
||||
// avoid routing loops.
|
||||
// This includes all control plane traffic (mgmt, signal, flows), relay, ICE/stun/turn and everything that is emitted by the wireguard socket.
|
||||
// It doesn't collide with the other marks, as the others are used for data plane traffic only.
|
||||
ControlPlaneMark = 0x1BD00
|
||||
|
||||
// Data plane marks (0x1BD10 - 0x1BDFF)
|
||||
|
||||
// DataPlaneMarkLower is the lowest value for the data plane range
|
||||
DataPlaneMarkLower = 0x1BD10
|
||||
// DataPlaneMarkUpper is the highest value for the data plane range
|
||||
DataPlaneMarkUpper = 0x1BDFF
|
||||
|
||||
// DataPlaneMarkIn is the mark for inbound data plane traffic.
|
||||
DataPlaneMarkIn = 0x1BD10
|
||||
|
||||
// DataPlaneMarkOut is the mark for outbound data plane traffic.
|
||||
DataPlaneMarkOut = 0x1BD11
|
||||
|
||||
// PreroutingFwmarkRedirected is applied to packets that are were redirected (input -> forward, e.g. by Docker or Podman) for special handling.
|
||||
PreroutingFwmarkRedirected = 0x1BD20
|
||||
|
||||
// PreroutingFwmarkMasquerade is applied to packets that arrive from the NetBird interface and should be masqueraded.
|
||||
PreroutingFwmarkMasquerade = 0x1BD21
|
||||
|
||||
// PreroutingFwmarkMasqueradeReturn is applied to packets that will leave through the NetBird interface and should be masqueraded.
|
||||
PreroutingFwmarkMasqueradeReturn = 0x1BD22
|
||||
)
|
||||
|
||||
// IsDataPlaneMark determines if a fwmark is in the data plane range (0x1BD10-0x1BDFF)
|
||||
func IsDataPlaneMark(fwmark uint32) bool {
|
||||
return fwmark >= DataPlaneMarkLower && fwmark <= DataPlaneMarkUpper
|
||||
}
|
||||
|
||||
// ConnectionID provides a globally unique identifier for network connections.
|
||||
// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
|
||||
type ConnectionID string
|
||||
|
||||
type AddHookFunc func(connID ConnectionID, IP net.IP) error
|
||||
type RemoveHookFunc func(connID ConnectionID) error
|
||||
|
||||
// GenerateConnID generates a unique identifier for each connection.
|
||||
func GenerateConnID() ConnectionID {
|
||||
return ConnectionID(uuid.NewString())
|
||||
}
|
||||
|
||||
func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) {
|
||||
var endIP net.IP
|
||||
addr := network.Addr().AsSlice()
|
||||
mask := net.CIDRMask(network.Bits(), len(addr)*8)
|
||||
|
||||
for i := 0; i < len(addr); i++ {
|
||||
endIP = append(endIP, addr[i]|^mask[i])
|
||||
}
|
||||
|
||||
// convert to big.Int
|
||||
endInt := big.NewInt(0)
|
||||
endInt.SetBytes(endIP)
|
||||
|
||||
// subtract fromEnd from the last ip
|
||||
fromEndBig := big.NewInt(int64(fromEnd))
|
||||
resultInt := big.NewInt(0)
|
||||
resultInt.Sub(endInt, fromEndBig)
|
||||
|
||||
ip, ok := netip.AddrFromSlice(resultInt.Bytes())
|
||||
if !ok {
|
||||
return netip.Addr{}, fmt.Errorf("invalid IP address from network %s", network)
|
||||
}
|
||||
|
||||
return ip.Unmap(), nil
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
//go:build !android
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// SetSocketMark sets the SO_MARK option on the given socket connection
|
||||
func SetSocketMark(conn syscall.Conn) error {
|
||||
if !AdvancedRouting() {
|
||||
return nil
|
||||
}
|
||||
|
||||
sysconn, err := conn.SyscallConn()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get raw conn: %w", err)
|
||||
}
|
||||
|
||||
return setRawSocketMark(sysconn)
|
||||
}
|
||||
|
||||
// SetSocketOpt sets the SO_MARK option on the given file descriptor
|
||||
func SetSocketOpt(fd int) error {
|
||||
if !AdvancedRouting() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return setSocketOptInt(fd)
|
||||
}
|
||||
|
||||
func setRawSocketMark(conn syscall.RawConn) error {
|
||||
var setErr error
|
||||
|
||||
err := conn.Control(func(fd uintptr) {
|
||||
if !AdvancedRouting() {
|
||||
return
|
||||
}
|
||||
setErr = setSocketOptInt(int(fd))
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("control: %w", err)
|
||||
}
|
||||
|
||||
if setErr != nil {
|
||||
return fmt.Errorf("set SO_MARK: %w", setErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setSocketOptInt(fd int) error {
|
||||
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, ControlPlaneMark)
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
network string
|
||||
fromEnd int
|
||||
expected string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "IPv4 /24 network - last IP (fromEnd=0)",
|
||||
network: "192.168.1.0/24",
|
||||
fromEnd: 0,
|
||||
expected: "192.168.1.255",
|
||||
},
|
||||
{
|
||||
name: "IPv4 /24 network - fromEnd=1",
|
||||
network: "192.168.1.0/24",
|
||||
fromEnd: 1,
|
||||
expected: "192.168.1.254",
|
||||
},
|
||||
{
|
||||
name: "IPv4 /24 network - fromEnd=5",
|
||||
network: "192.168.1.0/24",
|
||||
fromEnd: 5,
|
||||
expected: "192.168.1.250",
|
||||
},
|
||||
{
|
||||
name: "IPv4 /16 network - last IP",
|
||||
network: "10.0.0.0/16",
|
||||
fromEnd: 0,
|
||||
expected: "10.0.255.255",
|
||||
},
|
||||
{
|
||||
name: "IPv4 /16 network - fromEnd=256",
|
||||
network: "10.0.0.0/16",
|
||||
fromEnd: 256,
|
||||
expected: "10.0.254.255",
|
||||
},
|
||||
{
|
||||
name: "IPv4 /32 network - single host",
|
||||
network: "192.168.1.100/32",
|
||||
fromEnd: 0,
|
||||
expected: "192.168.1.100",
|
||||
},
|
||||
{
|
||||
name: "IPv6 /64 network - last IP",
|
||||
network: "2001:db8::/64",
|
||||
fromEnd: 0,
|
||||
expected: "2001:db8::ffff:ffff:ffff:ffff",
|
||||
},
|
||||
{
|
||||
name: "IPv6 /64 network - fromEnd=1",
|
||||
network: "2001:db8::/64",
|
||||
fromEnd: 1,
|
||||
expected: "2001:db8::ffff:ffff:ffff:fffe",
|
||||
},
|
||||
{
|
||||
name: "IPv6 /128 network - single host",
|
||||
network: "2001:db8::1/128",
|
||||
fromEnd: 0,
|
||||
expected: "2001:db8::1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
network, err := netip.ParsePrefix(tt.network)
|
||||
require.NoError(t, err, "Failed to parse network prefix")
|
||||
|
||||
result, err := GetLastIPFromNetwork(network, tt.fromEnd)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedIP, err := netip.ParseAddr(tt.expected)
|
||||
require.NoError(t, err, "Failed to parse expected IP")
|
||||
|
||||
assert.Equal(t, expectedIP, result, "IP mismatch for network %s with fromEnd=%d", tt.network, tt.fromEnd)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
)
|
||||
|
||||
var (
|
||||
androidProtectSocketLock sync.Mutex
|
||||
androidProtectSocket func(fd int32) bool
|
||||
)
|
||||
|
||||
func SetAndroidProtectSocketFn(fn func(fd int32) bool) {
|
||||
androidProtectSocketLock.Lock()
|
||||
androidProtectSocket = fn
|
||||
androidProtectSocketLock.Unlock()
|
||||
}
|
||||
|
||||
// ControlProtectSocket is a Control function that sets the fwmark on the socket
|
||||
func ControlProtectSocket(_, _ string, c syscall.RawConn) error {
|
||||
if netstack.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
var aErr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
androidProtectSocketLock.Lock()
|
||||
defer androidProtectSocketLock.Unlock()
|
||||
|
||||
if androidProtectSocket == nil {
|
||||
aErr = fmt.Errorf("socket protection function not set")
|
||||
return
|
||||
}
|
||||
|
||||
if !androidProtectSocket(int32(fd)) {
|
||||
aErr = fmt.Errorf("failed to protect socket via Android")
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return aErr
|
||||
}
|
||||
Reference in New Issue
Block a user