mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 17:26:40 +00:00
[client] Eliminate UDP proxy in user-space mode (#2712)
In the case of user space WireGuard mode, use in-memory proxy between the TURN/Relay connection and the WireGuard Bind. We keep the UDP proxy and eBPF proxy for kernel mode. The key change is the new wgproxy/bind and the iface/bind/ice_bind changes. Everything else is just to fulfill the dependencies.
This commit is contained in:
137
client/iface/wgproxy/bind/proxy.go
Normal file
137
client/iface/wgproxy/bind/proxy.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
)
|
||||
|
||||
type ProxyBind struct {
|
||||
Bind *bind.ICEBind
|
||||
|
||||
wgAddr *net.UDPAddr
|
||||
wgEndpoint *bind.Endpoint
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
}
|
||||
|
||||
// AddTurnConn adds a new connection to the bind.
|
||||
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
|
||||
// WireGuard configuration.
|
||||
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
|
||||
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.wgAddr = addr
|
||||
p.wgEndpoint = addrToEndpoint(addr)
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
return err
|
||||
|
||||
}
|
||||
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
||||
return p.wgAddr
|
||||
}
|
||||
|
||||
func (p *ProxyBind) Work() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
// Start the proxy only once
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyBind) Pause() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
|
||||
func (p *ProxyBind) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
}
|
||||
return p.close()
|
||||
}
|
||||
|
||||
func (p *ProxyBind) close() error {
|
||||
p.closeMu.Lock()
|
||||
defer p.closeMu.Unlock()
|
||||
|
||||
if p.closed {
|
||||
return nil
|
||||
}
|
||||
p.closed = true
|
||||
|
||||
p.cancel()
|
||||
|
||||
p.Bind.RemoveEndpoint(p.wgAddr)
|
||||
|
||||
return p.remoteConn.Close()
|
||||
}
|
||||
|
||||
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
defer func() {
|
||||
if err := p.close(); err != nil {
|
||||
log.Warnf("failed to close remote conn: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
msg := bind.RecvMessage{
|
||||
Endpoint: p.wgEndpoint,
|
||||
Buffer: buf[:n],
|
||||
}
|
||||
p.Bind.RecvChan <- msg
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
|
||||
ip, _ := netip.AddrFromSlice(addr.IP.To4())
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}
|
||||
}
|
||||
32
client/iface/wgproxy/ebpf/portlookup.go
Normal file
32
client/iface/wgproxy/ebpf/portlookup.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
portRangeStart = 3128
|
||||
portRangeEnd = 3228
|
||||
)
|
||||
|
||||
type portLookup struct {
|
||||
}
|
||||
|
||||
func (pl portLookup) searchFreePort() (int, error) {
|
||||
for i := portRangeStart; i <= portRangeEnd; i++ {
|
||||
if pl.tryToBind(i) == nil {
|
||||
return i, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("failed to bind free port for eBPF proxy")
|
||||
}
|
||||
|
||||
func (pl portLookup) tryToBind(port int) error {
|
||||
l, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = l.Close()
|
||||
return nil
|
||||
}
|
||||
42
client/iface/wgproxy/ebpf/portlookup_test.go
Normal file
42
client/iface/wgproxy/ebpf/portlookup_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_portLookup_searchFreePort(t *testing.T) {
|
||||
pl := portLookup{}
|
||||
_, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_portLookup_on_allocated(t *testing.T) {
|
||||
pl := portLookup{}
|
||||
|
||||
allocatedPort, err := allocatePort(portRangeStart)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer allocatedPort.Close()
|
||||
|
||||
fp, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if fp != (portRangeStart + 1) {
|
||||
t.Errorf("invalid free port, expected: %d, got: %d", portRangeStart+1, fp)
|
||||
}
|
||||
}
|
||||
|
||||
func allocatePort(port int) (net.PacketConn, error) {
|
||||
c, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
283
client/iface/wgproxy/ebpf/proxy.go
Normal file
283
client/iface/wgproxy/ebpf/proxy.go
Normal file
@@ -0,0 +1,283 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
loopbackAddr = "127.0.0.1"
|
||||
)
|
||||
|
||||
// WGEBPFProxy definition for proxy with EBPF support
|
||||
type WGEBPFProxy struct {
|
||||
localWGListenPort int
|
||||
|
||||
ebpfManager ebpfMgr.Manager
|
||||
turnConnStore map[uint16]net.Conn
|
||||
turnConnMutex sync.Mutex
|
||||
|
||||
lastUsedPort uint16
|
||||
rawConn net.PacketConn
|
||||
conn transport.UDPConn
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewWGEBPFProxy create new WGEBPFProxy instance
|
||||
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
|
||||
log.Debugf("instantiate ebpf proxy")
|
||||
wgProxy := &WGEBPFProxy{
|
||||
localWGListenPort: wgPort,
|
||||
ebpfManager: ebpf.GetEbpfManagerInstance(),
|
||||
turnConnStore: make(map[uint16]net.Conn),
|
||||
}
|
||||
return wgProxy
|
||||
}
|
||||
|
||||
// Listen load ebpf program and listen the proxy
|
||||
func (p *WGEBPFProxy) Listen() error {
|
||||
pl := portLookup{}
|
||||
wgPorxyPort, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.rawConn, err = p.prepareSenderRawSocket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addr := net.UDPAddr{
|
||||
Port: wgPorxyPort,
|
||||
IP: net.ParseIP(loopbackAddr),
|
||||
}
|
||||
|
||||
p.ctx, p.ctxCancel = context.WithCancel(context.Background())
|
||||
|
||||
conn, err := nbnet.ListenUDP("udp", &addr)
|
||||
if err != nil {
|
||||
if cErr := p.Free(); cErr != nil {
|
||||
log.Errorf("Failed to close the wgproxy: %s", cErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
p.conn = conn
|
||||
|
||||
go p.proxyToRemote()
|
||||
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddTurnConn add new turn connection for the proxy
|
||||
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) {
|
||||
wgEndpointPort, err := p.storeTurnConn(turnConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
|
||||
|
||||
wgEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP(loopbackAddr),
|
||||
Port: int(wgEndpointPort),
|
||||
}
|
||||
return wgEndpoint, nil
|
||||
}
|
||||
|
||||
// Free resources except the remoteConns will be keep open.
|
||||
func (p *WGEBPFProxy) Free() error {
|
||||
log.Debugf("free up ebpf wg proxy")
|
||||
if p.ctx != nil && p.ctx.Err() != nil {
|
||||
//nolint
|
||||
return nil
|
||||
}
|
||||
|
||||
p.ctxCancel()
|
||||
|
||||
var result *multierror.Error
|
||||
if p.conn != nil {
|
||||
if err := p.conn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.ebpfManager.FreeWGProxy(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
if err := p.rawConn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||
// From this go routine has only one instance.
|
||||
func (p *WGEBPFProxy) proxyToRemote() {
|
||||
buf := make([]byte, 1500)
|
||||
for p.ctx.Err() == nil {
|
||||
if err := p.readAndForwardPacket(buf); err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to proxy packet to remote conn: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) readAndForwardPacket(buf []byte) error {
|
||||
n, addr, err := p.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read UDP packet from WG: %w", err)
|
||||
}
|
||||
|
||||
p.turnConnMutex.Lock()
|
||||
conn, ok := p.turnConnStore[uint16(addr.Port)]
|
||||
p.turnConnMutex.Unlock()
|
||||
if !ok {
|
||||
if p.ctx.Err() == nil {
|
||||
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := conn.Write(buf[:n]); err != nil {
|
||||
return fmt.Errorf("failed to forward local WG packet (%d) to remote turn conn: %w", addr.Port, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
|
||||
p.turnConnMutex.Lock()
|
||||
defer p.turnConnMutex.Unlock()
|
||||
|
||||
np, err := p.nextFreePort()
|
||||
if err != nil {
|
||||
return np, err
|
||||
}
|
||||
p.turnConnStore[np] = turnConn
|
||||
return np, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
|
||||
p.turnConnMutex.Lock()
|
||||
defer p.turnConnMutex.Unlock()
|
||||
|
||||
_, ok := p.turnConnStore[turnConnID]
|
||||
if ok {
|
||||
log.Debugf("remove turn conn from store by port: %d", turnConnID)
|
||||
}
|
||||
delete(p.turnConnStore, turnConnID)
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) nextFreePort() (uint16, error) {
|
||||
if len(p.turnConnStore) == 65535 {
|
||||
return 0, fmt.Errorf("reached maximum turn connection numbers")
|
||||
}
|
||||
generatePort:
|
||||
if p.lastUsedPort == 65535 {
|
||||
p.lastUsedPort = 1
|
||||
} else {
|
||||
p.lastUsedPort++
|
||||
}
|
||||
|
||||
if _, ok := p.turnConnStore[p.lastUsedPort]; ok {
|
||||
goto generatePort
|
||||
}
|
||||
return p.lastUsedPort, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
||||
// Create a raw socket.
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
|
||||
// Bind the socket to the "lo" interface.
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = nbnet.SetSocketOpt(fd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert the file descriptor to a PacketConn.
|
||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||
if file == nil {
|
||||
return nil, fmt.Errorf("converting fd to file failed")
|
||||
}
|
||||
packetConn, err := net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
||||
localhost := net.ParseIP("127.0.0.1")
|
||||
|
||||
payload := gopacket.Payload(data)
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: localhost,
|
||||
SrcIP: localhost,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(port),
|
||||
DstPort: layers.UDPPort(p.localWGListenPort),
|
||||
}
|
||||
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
layerBuffer := gopacket.NewSerializeBuffer()
|
||||
|
||||
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
|
||||
return fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
56
client/iface/wgproxy/ebpf/proxy_test.go
Normal file
56
client/iface/wgproxy/ebpf/proxy_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWGEBPFProxy_connStore(t *testing.T) {
|
||||
wgProxy := NewWGEBPFProxy(1)
|
||||
|
||||
p, _ := wgProxy.storeTurnConn(nil)
|
||||
if p != 1 {
|
||||
t.Errorf("invalid initial port: %d", wgProxy.lastUsedPort)
|
||||
}
|
||||
|
||||
numOfConns := 10
|
||||
for i := 0; i < numOfConns; i++ {
|
||||
p, _ = wgProxy.storeTurnConn(nil)
|
||||
}
|
||||
if p != uint16(numOfConns)+1 {
|
||||
t.Errorf("invalid last used port: %d, expected: %d", p, numOfConns+1)
|
||||
}
|
||||
if len(wgProxy.turnConnStore) != numOfConns+1 {
|
||||
t.Errorf("invalid store size: %d, expected: %d", len(wgProxy.turnConnStore), numOfConns+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
|
||||
wgProxy := NewWGEBPFProxy(1)
|
||||
|
||||
_, _ = wgProxy.storeTurnConn(nil)
|
||||
wgProxy.lastUsedPort = 65535
|
||||
p, _ := wgProxy.storeTurnConn(nil)
|
||||
|
||||
if len(wgProxy.turnConnStore) != 2 {
|
||||
t.Errorf("invalid store size: %d, expected: %d", len(wgProxy.turnConnStore), 2)
|
||||
}
|
||||
|
||||
if p != 2 {
|
||||
t.Errorf("invalid last used port: %d, expected: %d", p, 2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
|
||||
wgProxy := NewWGEBPFProxy(1)
|
||||
|
||||
for i := 0; i < 65535; i++ {
|
||||
_, _ = wgProxy.storeTurnConn(nil)
|
||||
}
|
||||
|
||||
_, err := wgProxy.storeTurnConn(nil)
|
||||
if err == nil {
|
||||
t.Errorf("invalid turn conn store calculation")
|
||||
}
|
||||
}
|
||||
126
client/iface/wgproxy/ebpf/wrapper.go
Normal file
126
client/iface/wgproxy/ebpf/wrapper.go
Normal file
@@ -0,0 +1,126 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||
type ProxyWrapper struct {
|
||||
WgeBPFProxy *WGEBPFProxy
|
||||
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgEndpointAddr *net.UDPAddr
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add turn conn: %w", err)
|
||||
}
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.wgEndpointAddr = addr
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||
return p.wgEndpointAddr
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) Work() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) Pause() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
|
||||
p.pausedMu.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
|
||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||
func (e *ProxyWrapper) CloseConn() error {
|
||||
if e.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
}
|
||||
|
||||
e.cancel()
|
||||
|
||||
if err := e.remoteConn.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, err := p.readFromRemote(ctx, buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) {
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
47
client/iface/wgproxy/factory_kernel.go
Normal file
47
client/iface/wgproxy/factory_kernel.go
Normal file
@@ -0,0 +1,47 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
type KernelFactory struct {
|
||||
wgPort int
|
||||
|
||||
ebpfProxy *ebpf.WGEBPFProxy
|
||||
}
|
||||
|
||||
func NewKernelFactory(wgPort int) *KernelFactory {
|
||||
f := &KernelFactory{
|
||||
wgPort: wgPort,
|
||||
}
|
||||
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
|
||||
return f
|
||||
}
|
||||
f.ebpfProxy = ebpfProxy
|
||||
return f
|
||||
}
|
||||
|
||||
func (w *KernelFactory) GetProxy() Proxy {
|
||||
if w.ebpfProxy == nil {
|
||||
return udpProxy.NewWGUDPProxy(w.wgPort)
|
||||
}
|
||||
|
||||
return &ebpf.ProxyWrapper{
|
||||
WgeBPFProxy: w.ebpfProxy,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *KernelFactory) Free() error {
|
||||
if w.ebpfProxy == nil {
|
||||
return nil
|
||||
}
|
||||
return w.ebpfProxy.Free()
|
||||
}
|
||||
26
client/iface/wgproxy/factory_kernel_freebsd.go
Normal file
26
client/iface/wgproxy/factory_kernel_freebsd.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
// KernelFactory todo: check eBPF support on FreeBSD
|
||||
type KernelFactory struct {
|
||||
wgPort int
|
||||
}
|
||||
|
||||
func NewKernelFactory(wgPort int) *KernelFactory {
|
||||
f := &KernelFactory{
|
||||
wgPort: wgPort,
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
func (w *KernelFactory) GetProxy() Proxy {
|
||||
return udpProxy.NewWGUDPProxy(w.wgPort)
|
||||
}
|
||||
|
||||
func (w *KernelFactory) Free() error {
|
||||
return nil
|
||||
}
|
||||
27
client/iface/wgproxy/factory_usp.go
Normal file
27
client/iface/wgproxy/factory_usp.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||
)
|
||||
|
||||
type USPFactory struct {
|
||||
bind *bind.ICEBind
|
||||
}
|
||||
|
||||
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
|
||||
f := &USPFactory{
|
||||
bind: iceBind,
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func (w *USPFactory) GetProxy() Proxy {
|
||||
return &proxyBind.ProxyBind{
|
||||
Bind: w.bind,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *USPFactory) Free() error {
|
||||
return nil
|
||||
}
|
||||
15
client/iface/wgproxy/proxy.go
Normal file
15
client/iface/wgproxy/proxy.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Proxy is a transfer layer between the relayed connection and the WireGuard
|
||||
type Proxy interface {
|
||||
AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error
|
||||
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
|
||||
Work() // Work start or resume the proxy
|
||||
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
||||
CloseConn() error
|
||||
}
|
||||
56
client/iface/wgproxy/proxy_linux_test.go
Normal file
56
client/iface/wgproxy/proxy_linux_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
)
|
||||
|
||||
func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTIONS") != "true" {
|
||||
t.Skip("Skipping test as it requires root privileges")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
{
|
||||
name: "ebpf proxy",
|
||||
proxy: &ebpf.ProxyWrapper{
|
||||
WgeBPFProxy: ebpfProxy,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
relayedConn := newMockConn()
|
||||
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
||||
_ = relayedConn.Close()
|
||||
if err := tt.proxy.CloseConn(); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
128
client/iface/wgproxy/proxy_test.go
Normal file
128
client/iface/wgproxy/proxy_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
//go:build linux
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("trace", "console")
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
type mocConn struct {
|
||||
closeChan chan struct{}
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newMockConn() *mocConn {
|
||||
return &mocConn{
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mocConn) Read(b []byte) (n int, err error) {
|
||||
<-m.closeChan
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (m *mocConn) Write(b []byte) (n int, err error) {
|
||||
<-m.closeChan
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (m *mocConn) Close() error {
|
||||
if m.closed == true {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.closed = true
|
||||
close(m.closeChan)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mocConn) LocalAddr() net.Addr {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mocConn) RemoteAddr() net.Addr {
|
||||
return &net.UDPAddr{
|
||||
IP: net.ParseIP("172.16.254.1"),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mocConn) SetDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mocConn) SetReadDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mocConn) SetWriteDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
{
|
||||
name: "userspace proxy",
|
||||
proxy: udpProxy.NewWGUDPProxy(51830),
|
||||
},
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||
}
|
||||
}()
|
||||
proxyWrapper := &ebpf.ProxyWrapper{
|
||||
WgeBPFProxy: ebpfProxy,
|
||||
}
|
||||
|
||||
tests = append(tests, struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
name: "ebpf proxy",
|
||||
proxy: proxyWrapper,
|
||||
})
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
relayedConn := newMockConn()
|
||||
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
||||
_ = relayedConn.Close()
|
||||
if err := tt.proxy.CloseConn(); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
195
client/iface/wgproxy/udp/proxy.go
Normal file
195
client/iface/wgproxy/udp/proxy.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package udp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// WGUDPProxy proxies
|
||||
type WGUDPProxy struct {
|
||||
localWGListenPort int
|
||||
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
}
|
||||
|
||||
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
|
||||
func NewWGUDPProxy(wgPort int) *WGUDPProxy {
|
||||
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
||||
p := &WGUDPProxy{
|
||||
localWGListenPort: wgPort,
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// AddTurnConn
|
||||
// The provided Context must be non-nil. If the context expires before
|
||||
// the connection is complete, an error is returned. Once successfully
|
||||
// connected, any expiration of the context will not affect the
|
||||
// connection.
|
||||
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
dialer := net.Dialer{}
|
||||
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
if err != nil {
|
||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.localConn = localConn
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
|
||||
if p.localConn == nil {
|
||||
return nil
|
||||
}
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
||||
return endpointUdpAddr
|
||||
}
|
||||
|
||||
// Work starts the proxy or resumes it if it was paused
|
||||
func (p *WGUDPProxy) Work() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToRemote(p.ctx)
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Pause pauses the proxy from receiving data from the remote peer
|
||||
func (p *WGUDPProxy) Pause() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
|
||||
// CloseConn close the localConn
|
||||
func (p *WGUDPProxy) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
}
|
||||
return p.close()
|
||||
}
|
||||
|
||||
func (p *WGUDPProxy) close() error {
|
||||
p.closeMu.Lock()
|
||||
defer p.closeMu.Unlock()
|
||||
|
||||
// prevent double close
|
||||
if p.closed {
|
||||
return nil
|
||||
}
|
||||
p.closed = true
|
||||
|
||||
p.cancel()
|
||||
|
||||
var result *multierror.Error
|
||||
if err := p.remoteConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||
}
|
||||
|
||||
if err := p.localConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
|
||||
}
|
||||
return errors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// proxyToRemote proxies from Wireguard to the RemoteKey
|
||||
func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
|
||||
defer func() {
|
||||
if err := p.close(); err != nil {
|
||||
log.Warnf("error in proxy to remote loop: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for ctx.Err() == nil {
|
||||
n, err := p.localConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("failed to read from wg interface conn: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = p.remoteConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("failed to write to remote conn: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyToLocal proxies from the Remote peer to local WireGuard
|
||||
// if the proxy is paused it will drain the remote conn and drop the packets
|
||||
func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
|
||||
defer func() {
|
||||
if err := p.close(); err != nil {
|
||||
log.Warnf("error in proxy to local loop: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = p.localConn.Write(buf[:n])
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("failed to write to wg interface conn: %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user