Feature/exit nodes - Windows and macOS support (#1726)

This commit is contained in:
Viktor Liu
2024-04-03 11:11:46 +02:00
committed by GitHub
parent 9af532fe71
commit 7938295190
41 changed files with 2254 additions and 965 deletions

View File

@@ -1,11 +1,10 @@
//go:build !android
package grpc
import (
"context"
"net"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
nbnet "github.com/netbirdio/netbird/util/net"
@@ -13,6 +12,11 @@ import (
func WithCustomDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return nbnet.NewDialer().DialContext(ctx, "tcp", addr)
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, err
}
return conn, nil
})
}

View File

@@ -1,9 +0,0 @@
//go:build !linux || android
package grpc
import "google.golang.org/grpc"
func WithCustomDialer() grpc.DialOption {
return grpc.EmptyDialOption{}
}

64
util/net/dialer.go Normal file
View File

@@ -0,0 +1,64 @@
package net
import (
"fmt"
"net"
log "github.com/sirupsen/logrus"
)
// Dialer extends the standard net.Dialer with the ability to execute hooks before
// and after connections. This can be used to bypass the VPN for connections using this dialer.
type Dialer struct {
*net.Dialer
}
// NewDialer returns a customized net.Dialer with overridden Control method
func NewDialer() *Dialer {
dialer := &Dialer{
Dialer: &net.Dialer{},
}
dialer.init()
return dialer
}
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}
udpConn, ok := conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to closeConn connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type")
}
return udpConn, nil
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}
tcpConn, ok := conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got different type")
}
return tcpConn, nil
}

View File

@@ -1,19 +1,123 @@
//go:build !linux || android
//go:build !android && !ios
package net
import (
"context"
"fmt"
"net"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
)
func NewDialer() *net.Dialer {
return &net.Dialer{}
type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error
type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error
var (
dialerDialHooksMutex sync.RWMutex
dialerDialHooks []DialerDialHookFunc
dialerCloseHooksMutex sync.RWMutex
dialerCloseHooks []DialerCloseHookFunc
)
// AddDialerHook allows adding a new hook to be executed before dialing.
func AddDialerHook(hook DialerDialHookFunc) {
dialerDialHooksMutex.Lock()
defer dialerDialHooksMutex.Unlock()
dialerDialHooks = append(dialerDialHooks, hook)
}
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
return net.DialUDP(network, laddr, raddr)
// AddDialerCloseHook allows adding a new hook to be executed on connection close.
func AddDialerCloseHook(hook DialerCloseHookFunc) {
dialerCloseHooksMutex.Lock()
defer dialerCloseHooksMutex.Unlock()
dialerCloseHooks = append(dialerCloseHooks, hook)
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
return net.DialTCP(network, laddr, raddr)
// RemoveDialerHook removes all dialer hooks.
func RemoveDialerHooks() {
dialerDialHooksMutex.Lock()
defer dialerDialHooksMutex.Unlock()
dialerDialHooks = nil
dialerCloseHooksMutex.Lock()
defer dialerCloseHooksMutex.Unlock()
dialerCloseHooks = nil
}
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
var resolver *net.Resolver
if d.Resolver != nil {
resolver = d.Resolver
}
connID := GenerateConnID()
if dialerDialHooks != nil {
if err := calliDialerHooks(ctx, connID, address, resolver); err != nil {
log.Errorf("Failed to call dialer hooks: %v", err)
}
}
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("dial: %w", err)
}
// Wrap the connection in Conn to handle Close with hooks
return &Conn{Conn: conn, ID: connID}, nil
}
// Dial wraps the net.Dialer's Dial method to use the custom connection
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}
// Conn wraps a net.Conn to override the Close method
type Conn struct {
net.Conn
ID ConnectionID
}
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
func (c *Conn) Close() error {
err := c.Conn.Close()
dialerCloseHooksMutex.RLock()
defer dialerCloseHooksMutex.RUnlock()
for _, hook := range dialerCloseHooks {
if err := hook(c.ID, &c.Conn); err != nil {
log.Errorf("Error executing dialer close hook: %v", err)
}
}
return err
}
func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
host, _, err := net.SplitHostPort(address)
if err != nil {
return fmt.Errorf("split host and port: %w", err)
}
ips, err := resolver.LookupIPAddr(ctx, host)
if err != nil {
return fmt.Errorf("failed to resolve address %s: %w", address, err)
}
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
var result *multierror.Error
dialerDialHooksMutex.RLock()
defer dialerDialHooksMutex.RUnlock()
for _, hook := range dialerDialHooks {
if err := hook(ctx, connID, ips); err != nil {
result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err))
}
}
return result.ErrorOrNil()
}

View File

@@ -2,59 +2,11 @@
package net
import (
"context"
"fmt"
"net"
"syscall"
import "syscall"
log "github.com/sirupsen/logrus"
)
func NewDialer() *net.Dialer {
return &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
return SetRawSocketMark(c)
},
// init configures the net.Dialer Control function to set the fwmark on the socket
func (d *Dialer) init() {
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
return SetRawSocketMark(c)
}
}
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.DialContext(context.Background(), network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}
udpConn, ok := conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got different type")
}
return udpConn, nil
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.DialContext(context.Background(), network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}
tcpConn, ok := conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got different type")
}
return tcpConn, nil
}

View File

@@ -0,0 +1,6 @@
//go:build !linux || android
package net
func (d *Dialer) init() {
}

21
util/net/listener.go Normal file
View File

@@ -0,0 +1,21 @@
package net
import (
"net"
)
// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before
// responding via the socket and after closing. This can be used to bypass the VPN for listeners.
type ListenerConfig struct {
*net.ListenConfig
}
// NewListener creates a new ListenerConfig instance.
func NewListener() *ListenerConfig {
listener := &ListenerConfig{
ListenConfig: &net.ListenConfig{},
}
listener.init()
return listener
}

View File

@@ -1,13 +1,154 @@
//go:build !linux || android
//go:build !android && !ios
package net
import "net"
import (
"context"
"fmt"
"net"
"sync"
func NewListener() *net.ListenConfig {
return &net.ListenConfig{}
log "github.com/sirupsen/logrus"
)
// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn.
type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error
// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn.
type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error
var (
listenerWriteHooksMutex sync.RWMutex
listenerWriteHooks []ListenerWriteHookFunc
listenerCloseHooksMutex sync.RWMutex
listenerCloseHooks []ListenerCloseHookFunc
)
// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent.
func AddListenerWriteHook(hook ListenerWriteHookFunc) {
listenerWriteHooksMutex.Lock()
defer listenerWriteHooksMutex.Unlock()
listenerWriteHooks = append(listenerWriteHooks, hook)
}
func ListenUDP(network string, locAddr *net.UDPAddr) (*net.UDPConn, error) {
return net.ListenUDP(network, locAddr)
// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection.
func AddListenerCloseHook(hook ListenerCloseHookFunc) {
listenerCloseHooksMutex.Lock()
defer listenerCloseHooksMutex.Unlock()
listenerCloseHooks = append(listenerCloseHooks, hook)
}
// RemoveListenerHooks removes all dialer hooks.
func RemoveListenerHooks() {
listenerWriteHooksMutex.Lock()
defer listenerWriteHooksMutex.Unlock()
listenerWriteHooks = nil
listenerCloseHooksMutex.Lock()
defer listenerCloseHooksMutex.Unlock()
listenerCloseHooks = nil
}
// ListenPacket listens on the network address and returns a PacketConn
// which includes support for write hooks.
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("listen packet: %w", err)
}
connID := GenerateConnID()
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
}
// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality.
type PacketConn struct {
net.PacketConn
ID ConnectionID
seenAddrs *sync.Map
}
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
callWriteHooks(c.ID, c.seenAddrs, b, addr)
return c.PacketConn.WriteTo(b, addr)
}
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
func (c *PacketConn) Close() error {
c.seenAddrs = &sync.Map{}
return closeConn(c.ID, c.PacketConn)
}
// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality.
type UDPConn struct {
*net.UDPConn
ID ConnectionID
seenAddrs *sync.Map
}
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
callWriteHooks(c.ID, c.seenAddrs, b, addr)
return c.UDPConn.WriteTo(b, addr)
}
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
func (c *UDPConn) Close() error {
c.seenAddrs = &sync.Map{}
return closeConn(c.ID, c.UDPConn)
}
func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) {
// Lookup the address in the seenAddrs map to avoid calling the hooks for every write
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded {
ipStr, _, splitErr := net.SplitHostPort(addr.String())
if splitErr != nil {
log.Errorf("Error splitting IP address and port: %v", splitErr)
return
}
ip, err := net.ResolveIPAddr("ip", ipStr)
if err != nil {
log.Errorf("Error resolving IP address: %v", err)
return
}
log.Debugf("Listener resolved IP for %s: %s", addr, ip)
func() {
listenerWriteHooksMutex.RLock()
defer listenerWriteHooksMutex.RUnlock()
for _, hook := range listenerWriteHooks {
if err := hook(id, ip, b); err != nil {
log.Errorf("Error executing listener write hook: %v", err)
}
}
}()
}
}
func closeConn(id ConnectionID, conn net.PacketConn) error {
err := conn.Close()
listenerCloseHooksMutex.RLock()
defer listenerCloseHooksMutex.RUnlock()
for _, hook := range listenerCloseHooks {
if err := hook(id, conn); err != nil {
log.Errorf("Error executing listener close hook: %v", err)
}
}
return err
}
// ListenUDP listens on the network address and returns a transport.UDPConn
// which includes support for write and close hooks.
func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) {
udpConn, err := net.ListenUDP(network, laddr)
if err != nil {
return nil, fmt.Errorf("listen UDP: %w", err)
}
connID := GenerateConnID()
return &UDPConn{UDPConn: udpConn, ID: connID, seenAddrs: &sync.Map{}}, nil
}

View File

@@ -3,28 +3,12 @@
package net
import (
"context"
"fmt"
"net"
"syscall"
)
func NewListener() *net.ListenConfig {
return &net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
return SetRawSocketMark(c)
},
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
func (l *ListenerConfig) init() {
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
return SetRawSocketMark(c)
}
}
func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
pc, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listening on %s:%s with fwmark: %w", network, laddr, err)
}
udpConn, ok := pc.(*net.UDPConn)
if !ok {
return nil, fmt.Errorf("packetConn is not a *net.UDPConn")
}
return udpConn, nil
}

View File

@@ -0,0 +1,11 @@
//go:build android || ios
package net
import (
"net"
)
func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
return net.ListenUDP(network, laddr)
}

View File

@@ -0,0 +1,6 @@
//go:build !linux || android
package net
func (l *ListenerConfig) init() {
}

View File

@@ -1,6 +1,17 @@
package net
import "github.com/google/uuid"
const (
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
NetbirdFwmark = 0x1BD00
)
// ConnectionID provides a globally unique identifier for network connections.
// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
type ConnectionID string
// GenerateConnID generates a unique identifier for each connection.
func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString())
}