Compare commits

...

27 Commits

Author SHA1 Message Date
Zoltán Papp
2e6a44af1f Merge branch 'main' into fix/pkg-loss 2025-02-24 10:44:32 +01:00
Zoltán Papp
559d347588 Revert async WireGuard handshake 2025-02-21 16:19:28 +01:00
Zoltán Papp
06d71257b4 Build rawsocket code on linux only. 2025-02-21 16:00:31 +01:00
Zoltán Papp
62d10496ee Merge branch 'main' into fix/pkg-loss 2025-02-21 15:52:51 +01:00
Zoltán Papp
651e88d611 Eliminate code duplication 2025-02-21 14:59:42 +01:00
Zoltán Papp
d496d21693 Apply same pausedCond logic on all implementation 2025-02-21 14:55:56 +01:00
Zoltan Papp
648b4cdf72 Update client/iface/wgproxy/udp/proxy.go
Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
2025-02-21 14:50:29 +01:00
Zoltán Papp
f0020ad4ce Fallback to package loss solution if the raw socket does not work. 2025-02-18 13:58:20 +01:00
Zoltán Papp
1eacff250e Remove WireGuard kernel code from FreeBSD 2025-02-18 13:47:42 +01:00
Zoltán Papp
1f83ba4563 Ignore err in tests 2025-02-17 22:06:04 +01:00
Zoltán Papp
3d80a25b4d Fix possible blocker if the bind will be closed earlier then proxy 2025-02-17 22:03:15 +01:00
Zoltán Papp
1963644c99 Add close test for all implementation 2025-02-17 21:47:34 +01:00
Zoltán Papp
360c7134f7 Add unit test 2025-02-17 21:26:37 +01:00
Zoltán Papp
775b4feb7e Fix close operation 2025-02-17 20:17:47 +01:00
Zoltán Papp
aca443bdec Build UDP proxy on Linux only 2025-02-17 19:26:40 +01:00
Zoltán Papp
335866ac60 Close unused rawsocket 2025-02-17 15:56:48 +01:00
Zoltán Papp
082452eb5f Add info log line 2025-02-17 15:50:49 +01:00
Zoltan Papp
b17c1d96a5 [client] Improve WireGuard handshake success rate (#3092)
The controller peer sends WireGuard handshake requests only
2025-02-17 15:50:15 +01:00
Zoltán Papp
2d5b5f59c2 Remove log line 2025-02-17 15:49:04 +01:00
Zoltán Papp
d5042f688f Fix interface changes 2025-02-17 15:47:20 +01:00
Zoltán Papp
4db73a13d7 Implement redirect logic in UDP proxy 2025-02-17 15:47:20 +01:00
Zoltán Papp
06a17f0eee Implement redirect to in eBPF proxy 2025-02-17 15:47:20 +01:00
Zoltán Papp
1f088b7e69 Extend the proxy interface with RedirectTo function and implement it in Bind proxy 2025-02-17 15:47:20 +01:00
Zoltan Papp
ffe74365a8 Code cleaning 2025-02-17 15:47:20 +01:00
Zoltán Papp
6a0f6efc18 Always stop timer 2025-02-17 15:47:20 +01:00
Zoltán Papp
bfa6df13c5 The non handshake initiator peer start the handshake after timeout 2025-02-17 15:47:20 +01:00
Zoltán Papp
e9b3b6210d Improve WireGuard handshake success rate
The controller peer sends WireGuard
handshake requests only
2025-02-17 15:47:20 +01:00
20 changed files with 909 additions and 284 deletions

View File

@@ -1,5 +1,16 @@
package bind package bind
import wgConn "golang.zx2c4.com/wireguard/conn" import (
"net"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type Endpoint = wgConn.StdNetEndpoint type Endpoint = wgConn.StdNetEndpoint
func EndpointToUDPAddr(e Endpoint) *net.UDPAddr {
return &net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
}
}

View File

@@ -1,6 +1,7 @@
package bind package bind
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -38,7 +39,7 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD
// use the port because in the Send function the wgConn.Endpoint the port info is not exported. // use the port because in the Send function the wgConn.Endpoint the port info is not exported.
type ICEBind struct { type ICEBind struct {
*wgConn.StdNetBind *wgConn.StdNetBind
RecvChan chan RecvMessage recvChan chan RecvMessage
transportNet transport.Net transportNet transport.Net
filterFn FilterFn filterFn FilterFn
@@ -58,7 +59,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{ ib := &ICEBind{
StdNetBind: b, StdNetBind: b,
RecvChan: make(chan RecvMessage, 1), recvChan: make(chan RecvMessage, 1),
transportNet: transportNet, transportNet: transportNet,
filterFn: filterFn, filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn), endpoints: make(map[netip.Addr]net.Conn),
@@ -155,6 +156,14 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
return nil return nil
} }
func (b *ICEBind) Recv(ctx context.Context, msg RecvMessage) {
select {
case <-ctx.Done():
return
case b.recvChan <- msg:
}
}
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock() s.muUDPMux.Lock()
defer s.muUDPMux.Unlock() defer s.muUDPMux.Unlock()
@@ -264,7 +273,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
select { select {
case <-c.closedChan: case <-c.closedChan:
return 0, net.ErrClosed return 0, net.ErrClosed
case msg, ok := <-c.RecvChan: case msg, ok := <-c.recvChan:
if !ok { if !ok {
return 0, net.ErrClosed return 0, net.ErrClosed
} }

View File

@@ -0,0 +1,40 @@
//go:build freebsd
package iface
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil
}
return nil, fmt.Errorf("couldn't check or load tun module")
}

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build linux && !android
package iface package iface

View File

@@ -13,40 +13,58 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
) )
type ProxyBind struct { type IceBind interface {
Bind *bind.ICEBind SetEndpoint(addr *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error)
RemoveEndpoint(addr *net.UDPAddr)
Recv(ctx context.Context, msg bind.RecvMessage)
}
wgAddr *net.UDPAddr type ProxyBind struct {
wgEndpoint *bind.Endpoint bind IceBind
// wgEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address
wgRelayedEndpoint *bind.Endpoint
wgCurrentUsed *bind.Endpoint
remoteConn net.Conn remoteConn net.Conn
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
closeMu sync.Mutex closeMu sync.Mutex
closed bool closed bool
pausedMu sync.Mutex
paused bool paused bool
pausedCond *sync.Cond
isStarted bool isStarted bool
} }
func NewProxyBind(bind IceBind) *ProxyBind {
return &ProxyBind{
bind: bind,
pausedCond: sync.NewCond(&sync.Mutex{}),
}
}
// AddTurnConn adds a new connection to the bind. // AddTurnConn adds a new connection to the bind.
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the // endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration. // WireGuard configuration.
//
// Parameters:
// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages
// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address
// - remoteConn: The established TURN connection to the remote peer
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn) fakeAddr, err := p.bind.SetEndpoint(nbAddr, remoteConn)
if err != nil { if err != nil {
return err return err
} }
p.wgAddr = addr p.wgRelayedEndpoint = addrToEndpoint(fakeAddr)
p.wgEndpoint = addrToEndpoint(addr)
p.remoteConn = remoteConn p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
return err return err
} }
func (p *ProxyBind) EndpointAddr() *net.UDPAddr { func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return p.wgAddr return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint)
} }
func (p *ProxyBind) Work() { func (p *ProxyBind) Work() {
@@ -54,15 +72,20 @@ func (p *ProxyBind) Work() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock()
p.wgCurrentUsed = p.wgRelayedEndpoint
// Start the proxy only once // Start the proxy only once
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
go p.proxyToLocal(p.ctx) go p.proxyToLocal(p.ctx)
} }
p.pausedCond.L.Unlock()
// todo: review to should be inside the lock scope
p.pausedCond.Signal()
} }
func (p *ProxyBind) Pause() { func (p *ProxyBind) Pause() {
@@ -70,9 +93,19 @@ func (p *ProxyBind) Pause() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = true p.paused = true
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
}
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgCurrentUsed = addrToEndpoint(endpoint)
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
} }
func (p *ProxyBind) CloseConn() error { func (p *ProxyBind) CloseConn() error {
@@ -83,6 +116,10 @@ func (p *ProxyBind) CloseConn() error {
} }
func (p *ProxyBind) close() error { func (p *ProxyBind) close() error {
if p.remoteConn == nil {
return nil
}
p.closeMu.Lock() p.closeMu.Lock()
defer p.closeMu.Unlock() defer p.closeMu.Unlock()
@@ -93,7 +130,12 @@ func (p *ProxyBind) close() error {
p.cancel() p.cancel()
p.Bind.RemoveEndpoint(p.wgAddr) p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
p.bind.RemoveEndpoint(bind.EndpointToUDPAddr(*p.wgRelayedEndpoint))
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr return rErr
@@ -119,18 +161,17 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
if p.paused { for p.paused {
p.pausedMu.Unlock() p.pausedCond.Wait()
continue
} }
msg := bind.RecvMessage{ msg := bind.RecvMessage{
Endpoint: p.wgEndpoint, Endpoint: p.wgCurrentUsed,
Buffer: buf[:n], Buffer: buf[:n],
} }
p.Bind.RecvChan <- msg p.bind.Recv(ctx, msg)
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
} }
} }

View File

@@ -6,9 +6,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"os"
"sync" "sync"
"syscall"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
@@ -17,6 +15,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
"github.com/netbirdio/netbird/client/internal/ebpf" "github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@@ -26,6 +25,10 @@ const (
loopbackAddr = "127.0.0.1" loopbackAddr = "127.0.0.1"
) )
var (
localHostNetIP = net.ParseIP("127.0.0.1")
)
// WGEBPFProxy definition for proxy with EBPF support // WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct { type WGEBPFProxy struct {
localWGListenPort int localWGListenPort int
@@ -61,7 +64,7 @@ func (p *WGEBPFProxy) Listen() error {
return err return err
} }
p.rawConn, err = p.prepareSenderRawSocket() p.rawConn, err = rawsocket.PrepareSenderRawSocket()
if err != nil { if err != nil {
return err return err
} }
@@ -211,57 +214,17 @@ generatePort:
return p.lastUsedPort, nil return p.lastUsedPort, nil
} }
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
localhost := net.ParseIP("127.0.0.1")
payload := gopacket.Payload(data) payload := gopacket.Payload(data)
ipH := &layers.IPv4{ ipH := &layers.IPv4{
DstIP: localhost, DstIP: localHostNetIP,
SrcIP: localhost, SrcIP: endpointAddr.IP,
Version: 4, Version: 4,
TTL: 64, TTL: 64,
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
} }
udpH := &layers.UDP{ udpH := &layers.UDP{
SrcPort: layers.UDPPort(port), SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort), DstPort: layers.UDPPort(p.localWGListenPort),
} }
@@ -276,7 +239,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
if err != nil { if err != nil {
return fmt.Errorf("serialize layers: %w", err) return fmt.Errorf("serialize layers: %w", err)
} }
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil { if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err) return fmt.Errorf("write to raw conn: %w", err)
} }
return nil return nil

View File

@@ -15,32 +15,39 @@ import (
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct { type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy wgeBPFProxy *WGEBPFProxy
remoteConn net.Conn remoteConn net.Conn
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wgEndpointAddr *net.UDPAddr wgRelayedEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr
pausedMu sync.Mutex
paused bool paused bool
pausedCond *sync.Cond
isStarted bool isStarted bool
} }
func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{
wgeBPFProxy: proxy,
pausedCond: sync.NewCond(&sync.Mutex{}),
}
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil { if err != nil {
return fmt.Errorf("add turn conn: %w", err) return fmt.Errorf("add turn conn: %w", err)
} }
p.remoteConn = remoteConn p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
p.wgEndpointAddr = addr p.wgRelayedEndpointAddr = addr
return err return err
} }
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr return p.wgRelayedEndpointAddr
} }
func (p *ProxyWrapper) Work() { func (p *ProxyWrapper) Work() {
@@ -48,14 +55,19 @@ func (p *ProxyWrapper) Work() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock()
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
go p.proxyToLocal(p.ctx) go p.proxyToLocal(p.ctx)
} }
p.pausedCond.L.Unlock()
// todo: review to should be inside the lock scope
p.pausedCond.Signal()
} }
func (p *ProxyWrapper) Pause() { func (p *ProxyWrapper) Pause() {
@@ -64,27 +76,42 @@ func (p *ProxyWrapper) Pause() {
} }
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = true p.paused = true
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
}
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
} }
// CloseConn close the remoteConn and automatically remove the conn instance from the map // CloseConn close the remoteConn and automatically remove the conn instance from the map
func (e *ProxyWrapper) CloseConn() error { func (p *ProxyWrapper) CloseConn() error {
if e.cancel == nil { if p.cancel == nil {
return fmt.Errorf("proxy not started") return fmt.Errorf("proxy not started")
} }
e.cancel() p.cancel()
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err) return fmt.Errorf("failed to close remote conn: %w", err)
} }
return nil return nil
} }
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port))
buf := make([]byte, 1500) buf := make([]byte, 1500)
for { for {
@@ -93,14 +120,13 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
if p.paused { for p.paused {
p.pausedMu.Unlock() p.pausedCond.Wait()
continue
} }
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {
@@ -118,7 +144,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
return 0, ctx.Err() return 0, ctx.Err()
} }
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err)
} }
return 0, err return 0, err
} }

View File

@@ -36,9 +36,7 @@ func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort) return udpProxy.NewWGUDPProxy(w.wgPort)
} }
return &ebpf.ProxyWrapper{ return ebpf.NewProxyWrapper(w.ebpfProxy)
WgeBPFProxy: w.ebpfProxy,
}
} }
func (w *KernelFactory) Free() error { func (w *KernelFactory) Free() error {

View File

@@ -1,29 +0,0 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// KernelFactory todo: check eBPF support on FreeBSD
type KernelFactory struct {
wgPort int
}
func NewKernelFactory(wgPort int) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{
wgPort: wgPort,
}
return f
}
func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort)
}
func (w *KernelFactory) Free() error {
return nil
}

View File

@@ -20,9 +20,7 @@ func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
} }
func (w *USPFactory) GetProxy() Proxy { func (w *USPFactory) GetProxy() Proxy {
return &proxyBind.ProxyBind{ return proxyBind.NewProxyBind(w.bind)
Bind: w.bind,
}
} }
func (w *USPFactory) Free() error { func (w *USPFactory) Free() error {

View File

@@ -11,5 +11,11 @@ type Proxy interface {
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
Work() // Work start or resume the proxy Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
/*
RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused
and rewrite the src address to the endpoint address.
With this logic can avoid the package loss from relayed connections.
*/
RedirectAs(endpoint *net.UDPAddr)
CloseConn() error CloseConn() error
} }

View File

@@ -3,54 +3,78 @@
package wgproxy package wgproxy
import ( import (
"context" "fmt"
"os" "net"
"testing"
"github.com/netbirdio/netbird/client/iface/bind"
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
) )
func TestProxyCloseByRemoteConnEBPF(t *testing.T) { func seedProxies() ([]proxyInstance, error) {
if os.Getenv("GITHUB_ACTIONS") != "true" { pl := make([]proxyInstance, 0)
t.Skip("Skipping test as it requires root privileges")
}
ctx := context.Background()
ebpfProxy := ebpf.NewWGEBPFProxy(51831) ebpfProxy := ebpf.NewWGEBPFProxy(51831)
if err := ebpfProxy.Listen(); err != nil { if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err) return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
} }
defer func() { pEbpf := proxyInstance{
if err := ebpfProxy.Free(); err != nil { name: "ebpf kernel proxy",
t.Errorf("failed to free ebpf proxy: %s", err) proxy: ebpf.NewProxyWrapper(ebpfProxy),
wgPort: 51831,
closeFn: ebpfProxy.Free,
} }
}() pl = append(pl, pEbpf)
tests := []struct { pUDP := proxyInstance{
name string name: "udp kernel proxy",
proxy Proxy proxy: udp.NewWGUDPProxy(51832),
}{ wgPort: 51832,
{ closeFn: func() error { return nil },
name: "ebpf proxy",
proxy: &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
_ = relayedConn.Close()
if err := tt.proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
})
} }
pl = append(pl, pUDP)
return pl, nil
}
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
if err := ebpfProxy.Listen(); err != nil {
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
}
pEbpf := proxyInstance{
name: "ebpf kernel proxy",
proxy: ebpf.NewProxyWrapper(ebpfProxy),
wgPort: 51831,
closeFn: ebpfProxy.Free,
}
pl = append(pl, pEbpf)
pUDP := proxyInstance{
name: "udp kernel proxy",
proxy: udp.NewWGUDPProxy(51832),
wgPort: 51832,
closeFn: func() error { return nil },
}
pl = append(pl, pUDP)
iceBind := bind.NewICEBind(nil, nil)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,
}
pBind := proxyInstance{
name: "bind proxy",
proxy: bindproxy.NewProxyBind(iceBind),
endpointAddr: endpointAddress,
closeFn: func() error { return nil },
}
pl = append(pl, pBind)
return pl, nil
} }

View File

@@ -0,0 +1,34 @@
//go:build !linux
package wgproxy
import (
"net"
"github.com/netbirdio/netbird/client/iface/bind"
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
)
func seedProxies() ([]proxyInstance, error) {
// todo extend with Bind proxy
pl := make([]proxyInstance, 0)
return pl, nil
}
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
iceBind := bind.NewICEBind(nil, nil)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,
}
pBind := proxyInstance{
name: "bind proxy",
proxy: bindproxy.NewProxyBind(iceBind),
endpointAddr: endpointAddress,
closeFn: func() error { return nil },
}
pl = append(pl, pBind)
return pl, nil
}

View File

@@ -1,120 +1,142 @@
//go:build linux
package wgproxy package wgproxy
import ( import (
"context" "context"
"io"
"net" "net"
"os"
"runtime"
"testing" "testing"
"time"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
func TestMain(m *testing.M) { func init() {
_ = util.InitLog("trace", "console") _ = util.InitLog("debug", "console")
code := m.Run()
os.Exit(code)
} }
type mocConn struct { type proxyInstance struct {
closeChan chan struct{} name string
closed bool proxy Proxy
wgPort int
endpointAddr *net.UDPAddr
closeFn func() error
} }
func newMockConn() *mocConn { // TestProxyRedirect todo extend the proxies with Bind proxy
return &mocConn{ func TestProxyRedirect(t *testing.T) {
closeChan: make(chan struct{}), tests, err := seedProxies()
if err != nil {
t.Fatalf("error: %v", err)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr)
if err := tt.closeFn(); err != nil {
t.Errorf("error: %v", err)
}
})
} }
} }
func (m *mocConn) Read(b []byte) (n int, err error) { func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) {
<-m.closeChan t.Helper()
return 0, io.EOF
}
func (m *mocConn) Write(b []byte) (n int, err error) { msgHelloFromRelay := []byte("hello from relay")
<-m.closeChan msgRedirected := [][]byte{
return 0, io.EOF []byte("hello 1. to p2p"),
} []byte("hello 2. to p2p"),
[]byte("hello 3. to p2p"),
func (m *mocConn) Close() error {
if m.closed == true {
return nil
} }
m.closed = true dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{
close(m.closeChan) IP: net.IPv4(127, 0, 0, 1),
return nil Port: wgPort})
} if err != nil {
t.Fatalf("failed to listen on udp port: %s", err)
func (m *mocConn) LocalAddr() net.Addr {
panic("implement me")
}
func (m *mocConn) RemoteAddr() net.Addr {
return &net.UDPAddr{
IP: net.ParseIP("172.16.254.1"),
} }
}
func (m *mocConn) SetDeadline(t time.Time) error { relayedServer, _ := net.ListenUDP("udp",
panic("implement me") &net.UDPAddr{
} IP: net.IPv4(127, 0, 0, 1),
Port: 1234,
},
)
func (m *mocConn) SetReadDeadline(t time.Time) error { relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
panic("implement me")
}
func (m *mocConn) SetWriteDeadline(t time.Time) error { defer func() {
panic("implement me") _ = dummyWgListener.Close()
_ = relayedConn.Close()
_ = relayedServer.Close()
}()
if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil {
t.Errorf("error: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
}()
proxy.Work()
if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil {
t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err)
}
n, err := dummyWgListener.Read(make([]byte, 1024))
if err != nil {
t.Errorf("error: %v", err)
}
if n != len(msgHelloFromRelay) {
t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n)
}
p2pEndpointAddr := &net.UDPAddr{
IP: net.IPv4(192, 168, 0, 56),
Port: 1234,
}
proxy.RedirectAs(p2pEndpointAddr)
for _, msg := range msgRedirected {
if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil {
t.Errorf("error: %v", err)
}
}
for i := 0; i < len(msgRedirected); i++ {
buf := make([]byte, 1024)
n, rAddr, err := dummyWgListener.ReadFrom(buf)
if err != nil {
t.Errorf("error: %v", err)
}
if rAddr.String() != p2pEndpointAddr.String() {
t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String())
}
if string(buf[:n]) != string(msgRedirected[i]) {
t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n]))
}
}
} }
func TestProxyCloseByRemoteConn(t *testing.T) { func TestProxyCloseByRemoteConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
tests := []struct { tests, err := seedProxyForProxyCloseByRemoteConn()
name string if err != nil {
proxy Proxy t.Fatalf("error: %v", err)
}{
{
name: "userspace proxy",
proxy: udpProxy.NewWGUDPProxy(51830),
},
} }
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" { relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
defer func() { defer func() {
if err := ebpfProxy.Free(); err != nil { _ = relayedConn.Close()
t.Errorf("failed to free ebpf proxy: %s", err)
}
}() }()
proxyWrapper := &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
}
tests = append(tests, struct {
name string
proxy Proxy
}{
name: "ebpf proxy",
proxy: proxyWrapper,
})
}
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn() err := tt.proxy.AddTurnConn(ctx, tt.endpointAddr, relayedConn)
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
if err != nil { if err != nil {
t.Errorf("error: %v", err) t.Errorf("error: %v", err)
} }

View File

@@ -0,0 +1,50 @@
//go:build linux && !android
package rawsocket
import (
"fmt"
"net"
"os"
"syscall"
nbnet "github.com/netbirdio/netbird/util/net"
)
func PrepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
}

View File

@@ -1,3 +1,5 @@
//go:build linux && !android
package udp package udp
import ( import (
@@ -20,13 +22,15 @@ type WGUDPProxy struct {
remoteConn net.Conn remoteConn net.Conn
localConn net.Conn localConn net.Conn
srcFakerConn *SrcFaker
sendPkg func(data []byte) (int, error)
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
closeMu sync.Mutex closeMu sync.Mutex
closed bool closed bool
pausedMu sync.Mutex
paused bool paused bool
pausedCond *sync.Cond
isStarted bool isStarted bool
} }
@@ -35,6 +39,7 @@ func NewWGUDPProxy(wgPort int) *WGUDPProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort) log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUDPProxy{ p := &WGUDPProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
pausedCond: sync.NewCond(&sync.Mutex{}),
} }
return p return p
} }
@@ -54,6 +59,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
p.localConn = localConn p.localConn = localConn
p.sendPkg = p.localConn.Write
p.remoteConn = remoteConn p.remoteConn = remoteConn
return err return err
@@ -73,15 +79,24 @@ func (p *WGUDPProxy) Work() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock() p.sendPkg = p.localConn.Write
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
log.Errorf("failed to close src faker conn: %s", err)
}
p.srcFakerConn = nil
}
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
go p.proxyToRemote(p.ctx) go p.proxyToRemote(p.ctx)
go p.proxyToLocal(p.ctx) go p.proxyToLocal(p.ctx)
} }
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
} }
// Pause pauses the proxy from receiving data from the remote peer // Pause pauses the proxy from receiving data from the remote peer
@@ -90,9 +105,35 @@ func (p *WGUDPProxy) Pause() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = true p.paused = true
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
}
// RedirectAs start to use the fake sourced raw socket as package sender
func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
defer func() {
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
}()
p.paused = false
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
log.Errorf("failed to close src faker conn: %s", err)
}
p.srcFakerConn = nil
}
srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint)
if err != nil {
log.Errorf("failed to create src faker conn: %s", err)
// fallback to continue without redirecting
p.paused = true
return
}
p.srcFakerConn = srcFakerConn
p.sendPkg = p.srcFakerConn.SendPkg
} }
// CloseConn close the localConn // CloseConn close the localConn
@@ -104,6 +145,8 @@ func (p *WGUDPProxy) CloseConn() error {
} }
func (p *WGUDPProxy) close() error { func (p *WGUDPProxy) close() error {
var result *multierror.Error
p.closeMu.Lock() p.closeMu.Lock()
defer p.closeMu.Unlock() defer p.closeMu.Unlock()
@@ -111,11 +154,14 @@ func (p *WGUDPProxy) close() error {
if p.closed { if p.closed {
return nil return nil
} }
p.closed = true
p.cancel() p.cancel()
var result *multierror.Error p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
} }
@@ -123,6 +169,13 @@ func (p *WGUDPProxy) close() error {
if err := p.localConn.Close(); err != nil { if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
} }
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err))
}
}
return cerrors.FormatErrorOrNil(result) return cerrors.FormatErrorOrNil(result)
} }
@@ -175,14 +228,12 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
if p.paused { for p.paused {
p.pausedMu.Unlock() p.pausedCond.Wait()
continue
} }
_, err = p.sendPkg(buf[:n])
_, err = p.localConn.Write(buf[:n]) p.pausedCond.L.Unlock()
p.pausedMu.Unlock()
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {

View File

@@ -0,0 +1,101 @@
//go:build linux && !android
package udp
import (
"fmt"
"net"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
)
var (
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
localHostNetIPAddr = &net.IPAddr{
IP: net.ParseIP("127.0.0.1"),
}
)
type SrcFaker struct {
srcAddr *net.UDPAddr
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
}
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
rawSocket, err := rawsocket.PrepareSenderRawSocket()
if err != nil {
return nil, err
}
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
if err != nil {
return nil, err
}
f := &SrcFaker{
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
}
return f, nil
}
func (f *SrcFaker) Close() error {
return f.rawSocket.Close()
}
func (f *SrcFaker) SendPkg(data []byte) (int, error) {
defer func() {
if err := f.layerBuffer.Clear(); err != nil {
log.Errorf("failed to clear layer buffer: %s", err)
}
}()
payload := gopacket.Payload(data)
err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload)
if err != nil {
return 0, fmt.Errorf("serialize layers: %w", err)
}
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
if err != nil {
return 0, fmt.Errorf("write to raw conn: %w", err)
}
return n, nil
}
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
ipH := &layers.IPv4{
DstIP: net.ParseIP("127.0.0.1"),
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(srcAddr.Port),
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
}
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
}
return ipH, udpH, nil
}

View File

@@ -356,18 +356,27 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
conn.workerRelay.DisableWgWatcher() conn.workerRelay.DisableWgWatcher()
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
conn.log.Debugf("pause Relayed proxy")
conn.wgProxyRelay.Pause() conn.wgProxyRelay.Pause()
} }
if wgProxy != nil { if wgProxy != nil {
conn.log.Debugf("run ICE proxy")
wgProxy.Work() wgProxy.Work()
} }
conn.log.Infof("configure WireGuard endpoint to: %s", ep.String())
if err = conn.configureWGEndpoint(ep); err != nil { if err = conn.configureWGEndpoint(ep); err != nil {
conn.handleConfigurationFailure(err, wgProxy) conn.handleConfigurationFailure(err, wgProxy)
return return
} }
wgConfigWorkaround() wgConfigWorkaround()
if conn.wgProxyRelay != nil {
conn.log.Debugf("redirect packages from relayed conn to WireGuard")
conn.wgProxyRelay.RedirectAs(ep)
}
conn.currentConnPriority = priority conn.currentConnPriority = priority
conn.statusICE.Set(StatusConnected) conn.statusICE.Set(StatusConnected)
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
@@ -393,11 +402,12 @@ func (conn *Conn) onICEStateDisconnected() {
// switch back to relay connection // switch back to relay connection
if conn.isReadyToUpgrade() { if conn.isReadyToUpgrade() {
conn.log.Infof("ICE disconnected, set Relay to active connection") conn.log.Infof("ICE disconnected, set Relay to active connection")
conn.wgProxyRelay.Work()
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err) conn.log.Errorf("failed to switch to relay conn: %v", err)
} }
conn.wgProxyRelay.Work()
conn.workerRelay.EnableWgWatcher(conn.ctx) conn.workerRelay.EnableWgWatcher(conn.ctx)
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = connPriorityRelay
} else { } else {
@@ -760,6 +770,11 @@ func isController(config ConnConfig) bool {
return config.LocalKey > config.Key return config.LocalKey > config.Key
} }
// isWireGuardInitiator returns true if the local peer is the initiator of the WireGuard connection
func isWireGuardInitiator(config ConnConfig) bool {
return isController(config)
}
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
return remoteRosenpassPubKey != nil return remoteRosenpassPubKey != nil
} }

View File

@@ -0,0 +1,87 @@
package peer
import (
"context"
"net"
"sync"
"time"
"github.com/sirupsen/logrus"
)
// fallbackDelay could be const but because of testing it is a var
var fallbackDelay = 5 * time.Second
type endpointUpdater struct {
log *logrus.Entry
wgConfig WgConfig
initiator bool
cancelFunc func()
configUpdateMutex sync.Mutex
}
// configureWGEndpoint sets up the WireGuard endpoint configuration.
// The initiator immediately configures the endpoint, while the non-initiator
// waits for a fallback period before configuring to avoid handshake congestion.
func (e *endpointUpdater) configureWGEndpoint(addr *net.UDPAddr) error {
if e.initiator {
return e.updateWireGuardPeer(addr)
}
// prevent to run new update while cancel the previous update
e.configUpdateMutex.Lock()
if e.cancelFunc != nil {
e.cancelFunc()
}
e.configUpdateMutex.Unlock()
var ctx context.Context
ctx, e.cancelFunc = context.WithCancel(context.Background())
go e.scheduleDelayedUpdate(ctx, addr)
return e.updateWireGuardPeer(nil)
}
func (e *endpointUpdater) removeWgPeer() error {
e.configUpdateMutex.Lock()
defer e.configUpdateMutex.Unlock()
if e.cancelFunc != nil {
e.cancelFunc()
}
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
}
// scheduleDelayedUpdate waits for the fallback period before updating the endpoint
func (e *endpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr) {
t := time.NewTimer(fallbackDelay)
defer t.Stop()
select {
case <-ctx.Done():
return
case <-t.C:
e.configUpdateMutex.Lock()
defer e.configUpdateMutex.Unlock()
if ctx.Err() != nil {
return
}
if err := e.updateWireGuardPeer(addr); err != nil {
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
}
}
}
func (e *endpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr) error {
return e.wgConfig.WgInterface.UpdatePeer(
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
endpoint,
e.wgConfig.PreSharedKey,
)
}

View File

@@ -0,0 +1,178 @@
package peer
import (
"net"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/mock"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
type MockWgInterface struct {
mock.Mock
lastSetAddr *net.UDPAddr
}
func (m *MockWgInterface) GetStats(peerKey string) (configurer.WGStats, error) {
panic("implement me")
}
func (m *MockWgInterface) GetProxy() wgproxy.Proxy {
panic("implement me")
}
func (m *MockWgInterface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
args := m.Called(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
m.lastSetAddr = endpoint
return args.Error(0)
}
func (m *MockWgInterface) RemovePeer(publicKey string) error {
args := m.Called(publicKey)
return args.Error(0)
}
func Test_endpointUpdater_initiator(t *testing.T) {
mockWgInterface := &MockWgInterface{}
e := &endpointUpdater{
log: log.WithField("peer", "my-peer-key"),
wgConfig: WgConfig{
WgListenPort: 51820,
RemoteKey: "secret-remote-key",
WgInterface: mockWgInterface,
AllowedIps: "172.16.254.1",
},
initiator: true,
}
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1234,
}
mockWgInterface.On(
"UpdatePeer",
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
addr,
(*wgtypes.Key)(nil),
).Return(nil)
if err := e.configureWGEndpoint(addr); err != nil {
t.Fatalf("updateWireGuardPeer() failed: %v", err)
}
mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, addr, (*wgtypes.Key)(nil))
}
func Test_endpointUpdater_nonInitiator(t *testing.T) {
fallbackDelay = 1 * time.Second
mockWgInterface := &MockWgInterface{}
e := &endpointUpdater{
log: log.WithField("peer", "my-peer-key"),
wgConfig: WgConfig{
WgListenPort: 51820,
RemoteKey: "secret-remote-key",
WgInterface: mockWgInterface,
AllowedIps: "172.16.254.1",
},
initiator: false,
}
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1234,
}
mockWgInterface.On(
"UpdatePeer",
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
(*net.UDPAddr)(nil),
(*wgtypes.Key)(nil),
).Return(nil)
mockWgInterface.On(
"UpdatePeer",
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
addr,
(*wgtypes.Key)(nil),
).Return(nil)
err := e.configureWGEndpoint(addr)
if err != nil {
t.Fatalf("updateWireGuardPeer() failed: %v", err)
}
mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, (*net.UDPAddr)(nil), (*wgtypes.Key)(nil))
time.Sleep(fallbackDelay + time.Second)
mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, addr, (*wgtypes.Key)(nil))
}
func Test_endpointUpdater_overRule(t *testing.T) {
fallbackDelay = 1 * time.Second
mockWgInterface := &MockWgInterface{}
e := &endpointUpdater{
log: log.WithField("peer", "my-peer-key"),
wgConfig: WgConfig{
WgListenPort: 51820,
RemoteKey: "secret-remote-key",
WgInterface: mockWgInterface,
AllowedIps: "172.16.254.1",
},
initiator: false,
}
addr1 := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1000,
}
addr2 := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1001,
}
mockWgInterface.On(
"UpdatePeer",
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
(*net.UDPAddr)(nil),
(*wgtypes.Key)(nil),
).Return(nil)
mockWgInterface.On(
"UpdatePeer",
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
addr2,
(*wgtypes.Key)(nil),
).Return(nil)
if err := e.configureWGEndpoint(addr1); err != nil {
t.Fatalf("updateWireGuardPeer() failed: %v", err)
}
mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, (*net.UDPAddr)(nil), (*wgtypes.Key)(nil))
if err := e.configureWGEndpoint(addr2); err != nil {
t.Fatalf("updateWireGuardPeer() failed: %v", err)
}
time.Sleep(fallbackDelay + time.Second)
mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, addr2, (*wgtypes.Key)(nil))
if mockWgInterface.lastSetAddr != addr2 {
t.Fatalf("lastSetAddr is not equal to addr2")
}
}