mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 00:06:38 +00:00
[client] Eliminate UDP proxy in user-space mode (#2712)
In the case of user space WireGuard mode, use in-memory proxy between the TURN/Relay connection and the WireGuard Bind. We keep the UDP proxy and eBPF proxy for kernel mode. The key change is the new wgproxy/bind and the iface/bind/ice_bind changes. Everything else is just to fulfill the dependencies.
This commit is contained in:
@@ -267,7 +267,17 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun230%d", n),
|
||||
Address: fmt.Sprintf("100.66.100.%d/32", n+1),
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -345,7 +355,15 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: "utun2301",
|
||||
Address: "100.66.100.1/32",
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Errorf("build interface wireguard: %v", err)
|
||||
return
|
||||
@@ -803,7 +821,17 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
}
|
||||
|
||||
privKey, _ := wgtypes.GeneratePrivateKey()
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: "utun2301",
|
||||
Address: "100.66.100.2/24",
|
||||
WGPort: 33100,
|
||||
WGPrivKey: privKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("build interface wireguard: %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -35,7 +35,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -141,8 +140,7 @@ type Engine struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgInterface iface.IWGIface
|
||||
wgProxyFactory *wgproxy.Factory
|
||||
wgInterface iface.IWGIface
|
||||
|
||||
udpMux *bind.UniversalUDPMuxDefault
|
||||
|
||||
@@ -299,9 +297,6 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.wgInterface = wgIface
|
||||
|
||||
userspace := e.wgInterface.IsUserspaceBind()
|
||||
e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort)
|
||||
|
||||
if e.config.RosenpassEnabled {
|
||||
log.Infof("rosenpass is enabled")
|
||||
if e.config.RosenpassPermissive {
|
||||
@@ -966,7 +961,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
||||
},
|
||||
}
|
||||
|
||||
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.wgProxyFactory, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager)
|
||||
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1117,12 +1112,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
}
|
||||
|
||||
func (e *Engine) close() {
|
||||
if e.wgProxyFactory != nil {
|
||||
if err := e.wgProxyFactory.Free(); err != nil {
|
||||
log.Errorf("failed closing ebpf proxy: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
if e.wgInterface != nil {
|
||||
if err := e.wgInterface.Close(); err != nil {
|
||||
@@ -1167,21 +1156,29 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||
}
|
||||
|
||||
var mArgs *device.MobileIFaceArguments
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: e.config.WgIfaceName,
|
||||
Address: e.config.WgAddr,
|
||||
WGPort: e.config.WgPort,
|
||||
WGPrivKey: e.config.WgPrivateKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: transportNet,
|
||||
FilterFn: e.addrViaRoutes,
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
mArgs = &device.MobileIFaceArguments{
|
||||
opts.MobileArgs = &device.MobileIFaceArguments{
|
||||
TunAdapter: e.mobileDep.TunAdapter,
|
||||
TunFd: int(e.mobileDep.FileDescriptor),
|
||||
}
|
||||
case "ios":
|
||||
mArgs = &device.MobileIFaceArguments{
|
||||
opts.MobileArgs = &device.MobileIFaceArguments{
|
||||
TunFd: int(e.mobileDep.FileDescriptor),
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes)
|
||||
return iface.NewWGIFace(opts)
|
||||
}
|
||||
|
||||
func (e *Engine) wgInterfaceCreate() (err error) {
|
||||
|
||||
@@ -602,7 +602,16 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: wgIfaceName,
|
||||
Address: wgAddr,
|
||||
WGPort: engine.config.WgPort,
|
||||
WGPrivKey: key.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(opts)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
input := struct {
|
||||
inputSerial uint64
|
||||
@@ -774,7 +783,15 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: wgIfaceName,
|
||||
Address: wgAddr,
|
||||
WGPort: 33100,
|
||||
WGPrivKey: key.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
engine.wgInterface, err = iface.NewWGIFace(opts)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
mockRouteManager := &routemanager.MockManager{
|
||||
|
||||
@@ -17,8 +17,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
@@ -81,11 +81,10 @@ type Conn struct {
|
||||
ctxCancel context.CancelFunc
|
||||
config ConnConfig
|
||||
statusRecorder *Status
|
||||
wgProxyFactory *wgproxy.Factory
|
||||
signaler *Signaler
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
relayManager *relayClient.Manager
|
||||
allowedIPsIP string
|
||||
allowedIP net.IP
|
||||
allowedNet string
|
||||
handshaker *Handshaker
|
||||
|
||||
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
|
||||
@@ -116,8 +115,8 @@ type Conn struct {
|
||||
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
// To establish a connection run Conn.Open
|
||||
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) {
|
||||
_, allowedIPsIP, err := net.ParseCIDR(config.WgConfig.AllowedIps)
|
||||
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) {
|
||||
allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps)
|
||||
if err != nil {
|
||||
log.Errorf("failed to parse allowedIPS: %v", err)
|
||||
return nil, err
|
||||
@@ -127,19 +126,17 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
||||
connLog := log.WithField("peer", config.Key)
|
||||
|
||||
var conn = &Conn{
|
||||
log: connLog,
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
wgProxyFactory: wgProxyFactory,
|
||||
signaler: signaler,
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
relayManager: relayManager,
|
||||
allowedIPsIP: allowedIPsIP.String(),
|
||||
statusRelay: NewAtomicConnStatus(),
|
||||
statusICE: NewAtomicConnStatus(),
|
||||
|
||||
log: connLog,
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
signaler: signaler,
|
||||
relayManager: relayManager,
|
||||
allowedIP: allowedIP,
|
||||
allowedNet: allowedNet.String(),
|
||||
statusRelay: NewAtomicConnStatus(),
|
||||
statusICE: NewAtomicConnStatus(),
|
||||
iCEDisconnected: make(chan bool, 1),
|
||||
relayDisconnected: make(chan bool, 1),
|
||||
}
|
||||
@@ -692,7 +689,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
|
||||
}
|
||||
|
||||
if conn.onConnected != nil {
|
||||
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIPsIP, remoteRosenpassAddr)
|
||||
conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedNet, remoteRosenpassAddr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -783,8 +780,13 @@ func (conn *Conn) freeUpConnID() {
|
||||
|
||||
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
conn.log.Debugf("setup proxied WireGuard connection")
|
||||
wgProxy := conn.wgProxyFactory.GetProxy()
|
||||
if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil {
|
||||
udpAddr := &net.UDPAddr{
|
||||
IP: conn.allowedIP,
|
||||
Port: conn.config.WgConfig.WgListenPort,
|
||||
}
|
||||
|
||||
wgProxy := conn.config.WgConfig.WgInterface.GetProxy()
|
||||
if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil {
|
||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -44,11 +43,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_GetKey(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(context.Background(), connConf, nil, wgProxyFactory, nil, nil, nil)
|
||||
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -59,11 +54,7 @@ func TestConn_GetKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -96,11 +87,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -132,11 +119,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
func TestConn_Status(t *testing.T) {
|
||||
wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
|
||||
defer func() {
|
||||
_ = wgProxyFactory.Free()
|
||||
}()
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
|
||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -407,7 +407,15 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun43%d", n),
|
||||
Address: "100.65.65.2/24",
|
||||
WGPort: 33100,
|
||||
WGPrivKey: peerPrivateKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(opts)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
|
||||
@@ -61,7 +61,14 @@ func TestAddRemoveRoutes(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun53%d", n),
|
||||
Address: "100.65.75.2/24",
|
||||
WGPrivKey: peerPrivateKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(opts)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
@@ -213,7 +220,15 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: fmt.Sprintf("utun53%d", n),
|
||||
Address: "100.65.75.2/24",
|
||||
WGPort: 33100,
|
||||
WGPrivKey: peerPrivateKey.String(),
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(opts)
|
||||
require.NoError(t, err, "should create testing WGIface interface")
|
||||
defer wgInterface.Close()
|
||||
|
||||
@@ -345,7 +360,15 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
|
||||
newNet, err := stdnet.NewNet()
|
||||
require.NoError(t, err)
|
||||
|
||||
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
|
||||
opts := iface.WGIFaceOpts{
|
||||
IFaceName: interfaceName,
|
||||
Address: ipAddressCIDR,
|
||||
WGPrivKey: peerPrivateKey.String(),
|
||||
WGPort: listenPort,
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: newNet,
|
||||
}
|
||||
wgInterface, err := iface.NewWGIFace(opts)
|
||||
require.NoError(t, err, "should create testing WireGuard interface")
|
||||
|
||||
err = wgInterface.Create()
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
portRangeStart = 3128
|
||||
portRangeEnd = 3228
|
||||
)
|
||||
|
||||
type portLookup struct {
|
||||
}
|
||||
|
||||
func (pl portLookup) searchFreePort() (int, error) {
|
||||
for i := portRangeStart; i <= portRangeEnd; i++ {
|
||||
if pl.tryToBind(i) == nil {
|
||||
return i, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("failed to bind free port for eBPF proxy")
|
||||
}
|
||||
|
||||
func (pl portLookup) tryToBind(port int) error {
|
||||
l, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = l.Close()
|
||||
return nil
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_portLookup_searchFreePort(t *testing.T) {
|
||||
pl := portLookup{}
|
||||
_, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_portLookup_on_allocated(t *testing.T) {
|
||||
pl := portLookup{}
|
||||
|
||||
allocatedPort, err := allocatePort(portRangeStart)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer allocatedPort.Close()
|
||||
|
||||
fp, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if fp != (portRangeStart + 1) {
|
||||
t.Errorf("invalid free port, expected: %d, got: %d", portRangeStart+1, fp)
|
||||
}
|
||||
}
|
||||
|
||||
func allocatePort(port int) (net.PacketConn, error) {
|
||||
c, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
@@ -1,283 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const (
|
||||
loopbackAddr = "127.0.0.1"
|
||||
)
|
||||
|
||||
// WGEBPFProxy definition for proxy with EBPF support
|
||||
type WGEBPFProxy struct {
|
||||
localWGListenPort int
|
||||
|
||||
ebpfManager ebpfMgr.Manager
|
||||
turnConnStore map[uint16]net.Conn
|
||||
turnConnMutex sync.Mutex
|
||||
|
||||
lastUsedPort uint16
|
||||
rawConn net.PacketConn
|
||||
conn transport.UDPConn
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewWGEBPFProxy create new WGEBPFProxy instance
|
||||
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
|
||||
log.Debugf("instantiate ebpf proxy")
|
||||
wgProxy := &WGEBPFProxy{
|
||||
localWGListenPort: wgPort,
|
||||
ebpfManager: ebpf.GetEbpfManagerInstance(),
|
||||
turnConnStore: make(map[uint16]net.Conn),
|
||||
}
|
||||
return wgProxy
|
||||
}
|
||||
|
||||
// Listen load ebpf program and listen the proxy
|
||||
func (p *WGEBPFProxy) Listen() error {
|
||||
pl := portLookup{}
|
||||
wgPorxyPort, err := pl.searchFreePort()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.rawConn, err = p.prepareSenderRawSocket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addr := net.UDPAddr{
|
||||
Port: wgPorxyPort,
|
||||
IP: net.ParseIP(loopbackAddr),
|
||||
}
|
||||
|
||||
p.ctx, p.ctxCancel = context.WithCancel(context.Background())
|
||||
|
||||
conn, err := nbnet.ListenUDP("udp", &addr)
|
||||
if err != nil {
|
||||
if cErr := p.Free(); cErr != nil {
|
||||
log.Errorf("Failed to close the wgproxy: %s", cErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
p.conn = conn
|
||||
|
||||
go p.proxyToRemote()
|
||||
log.Infof("local wg proxy listening on: %d", wgPorxyPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddTurnConn add new turn connection for the proxy
|
||||
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) {
|
||||
wgEndpointPort, err := p.storeTurnConn(turnConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
|
||||
|
||||
wgEndpoint := &net.UDPAddr{
|
||||
IP: net.ParseIP(loopbackAddr),
|
||||
Port: int(wgEndpointPort),
|
||||
}
|
||||
return wgEndpoint, nil
|
||||
}
|
||||
|
||||
// Free resources except the remoteConns will be keep open.
|
||||
func (p *WGEBPFProxy) Free() error {
|
||||
log.Debugf("free up ebpf wg proxy")
|
||||
if p.ctx != nil && p.ctx.Err() != nil {
|
||||
//nolint
|
||||
return nil
|
||||
}
|
||||
|
||||
p.ctxCancel()
|
||||
|
||||
var result *multierror.Error
|
||||
if p.conn != nil { // p.conn will be nil if we have failed to listen
|
||||
if err := p.conn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.ebpfManager.FreeWGProxy(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
|
||||
if err := p.rawConn.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
|
||||
// From this go routine has only one instance.
|
||||
func (p *WGEBPFProxy) proxyToRemote() {
|
||||
buf := make([]byte, 1500)
|
||||
for p.ctx.Err() == nil {
|
||||
if err := p.readAndForwardPacket(buf); err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to proxy packet to remote conn: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) readAndForwardPacket(buf []byte) error {
|
||||
n, addr, err := p.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read UDP packet from WG: %w", err)
|
||||
}
|
||||
|
||||
p.turnConnMutex.Lock()
|
||||
conn, ok := p.turnConnStore[uint16(addr.Port)]
|
||||
p.turnConnMutex.Unlock()
|
||||
if !ok {
|
||||
if p.ctx.Err() == nil {
|
||||
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := conn.Write(buf[:n]); err != nil {
|
||||
return fmt.Errorf("failed to forward local WG packet (%d) to remote turn conn: %w", addr.Port, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
|
||||
p.turnConnMutex.Lock()
|
||||
defer p.turnConnMutex.Unlock()
|
||||
|
||||
np, err := p.nextFreePort()
|
||||
if err != nil {
|
||||
return np, err
|
||||
}
|
||||
p.turnConnStore[np] = turnConn
|
||||
return np, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
|
||||
p.turnConnMutex.Lock()
|
||||
defer p.turnConnMutex.Unlock()
|
||||
|
||||
_, ok := p.turnConnStore[turnConnID]
|
||||
if ok {
|
||||
log.Debugf("remove turn conn from store by port: %d", turnConnID)
|
||||
}
|
||||
delete(p.turnConnStore, turnConnID)
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) nextFreePort() (uint16, error) {
|
||||
if len(p.turnConnStore) == 65535 {
|
||||
return 0, fmt.Errorf("reached maximum turn connection numbers")
|
||||
}
|
||||
generatePort:
|
||||
if p.lastUsedPort == 65535 {
|
||||
p.lastUsedPort = 1
|
||||
} else {
|
||||
p.lastUsedPort++
|
||||
}
|
||||
|
||||
if _, ok := p.turnConnStore[p.lastUsedPort]; ok {
|
||||
goto generatePort
|
||||
}
|
||||
return p.lastUsedPort, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
||||
// Create a raw socket.
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
|
||||
// Bind the socket to the "lo" interface.
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = nbnet.SetSocketOpt(fd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert the file descriptor to a PacketConn.
|
||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||
if file == nil {
|
||||
return nil, fmt.Errorf("converting fd to file failed")
|
||||
}
|
||||
packetConn, err := net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
||||
localhost := net.ParseIP("127.0.0.1")
|
||||
|
||||
payload := gopacket.Payload(data)
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: localhost,
|
||||
SrcIP: localhost,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(port),
|
||||
DstPort: layers.UDPPort(p.localWGListenPort),
|
||||
}
|
||||
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
layerBuffer := gopacket.NewSerializeBuffer()
|
||||
|
||||
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
|
||||
return fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWGEBPFProxy_connStore(t *testing.T) {
|
||||
wgProxy := NewWGEBPFProxy(1)
|
||||
|
||||
p, _ := wgProxy.storeTurnConn(nil)
|
||||
if p != 1 {
|
||||
t.Errorf("invalid initial port: %d", wgProxy.lastUsedPort)
|
||||
}
|
||||
|
||||
numOfConns := 10
|
||||
for i := 0; i < numOfConns; i++ {
|
||||
p, _ = wgProxy.storeTurnConn(nil)
|
||||
}
|
||||
if p != uint16(numOfConns)+1 {
|
||||
t.Errorf("invalid last used port: %d, expected: %d", p, numOfConns+1)
|
||||
}
|
||||
if len(wgProxy.turnConnStore) != numOfConns+1 {
|
||||
t.Errorf("invalid store size: %d, expected: %d", len(wgProxy.turnConnStore), numOfConns+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
|
||||
wgProxy := NewWGEBPFProxy(1)
|
||||
|
||||
_, _ = wgProxy.storeTurnConn(nil)
|
||||
wgProxy.lastUsedPort = 65535
|
||||
p, _ := wgProxy.storeTurnConn(nil)
|
||||
|
||||
if len(wgProxy.turnConnStore) != 2 {
|
||||
t.Errorf("invalid store size: %d, expected: %d", len(wgProxy.turnConnStore), 2)
|
||||
}
|
||||
|
||||
if p != 2 {
|
||||
t.Errorf("invalid last used port: %d, expected: %d", p, 2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
|
||||
wgProxy := NewWGEBPFProxy(1)
|
||||
|
||||
for i := 0; i < 65535; i++ {
|
||||
_, _ = wgProxy.storeTurnConn(nil)
|
||||
}
|
||||
|
||||
_, err := wgProxy.storeTurnConn(nil)
|
||||
if err == nil {
|
||||
t.Errorf("invalid turn conn store calculation")
|
||||
}
|
||||
}
|
||||
@@ -1,126 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||
type ProxyWrapper struct {
|
||||
WgeBPFProxy *WGEBPFProxy
|
||||
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgEndpointAddr *net.UDPAddr
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
||||
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add turn conn: %w", err)
|
||||
}
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.wgEndpointAddr = addr
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||
return p.wgEndpointAddr
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) Work() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) Pause() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
|
||||
p.pausedMu.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
|
||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||
func (e *ProxyWrapper) CloseConn() error {
|
||||
if e.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
}
|
||||
|
||||
e.cancel()
|
||||
|
||||
if err := e.remoteConn.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, err := p.readFromRemote(ctx, buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to write out turn pkg to local conn: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) {
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
//go:build !android
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
|
||||
)
|
||||
|
||||
type Factory struct {
|
||||
wgPort int
|
||||
ebpfProxy *ebpf.WGEBPFProxy
|
||||
}
|
||||
|
||||
func NewFactory(userspace bool, wgPort int) *Factory {
|
||||
f := &Factory{wgPort: wgPort}
|
||||
|
||||
if userspace {
|
||||
return f
|
||||
}
|
||||
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
|
||||
err := ebpfProxy.Listen()
|
||||
if err != nil {
|
||||
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
|
||||
return f
|
||||
}
|
||||
|
||||
f.ebpfProxy = ebpfProxy
|
||||
return f
|
||||
}
|
||||
|
||||
func (w *Factory) GetProxy() Proxy {
|
||||
if w.ebpfProxy != nil {
|
||||
p := &ebpf.ProxyWrapper{
|
||||
WgeBPFProxy: w.ebpfProxy,
|
||||
}
|
||||
return p
|
||||
}
|
||||
return usp.NewWGUserSpaceProxy(w.wgPort)
|
||||
}
|
||||
|
||||
func (w *Factory) Free() error {
|
||||
if w.ebpfProxy == nil {
|
||||
return nil
|
||||
}
|
||||
return w.ebpfProxy.Free()
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package wgproxy
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/wgproxy/usp"
|
||||
|
||||
type Factory struct {
|
||||
wgPort int
|
||||
}
|
||||
|
||||
func NewFactory(_ bool, wgPort int) *Factory {
|
||||
return &Factory{wgPort: wgPort}
|
||||
}
|
||||
|
||||
func (w *Factory) GetProxy() Proxy {
|
||||
return usp.NewWGUserSpaceProxy(w.wgPort)
|
||||
}
|
||||
|
||||
func (w *Factory) Free() error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// Proxy is a transfer layer between the relayed connection and the WireGuard
|
||||
type Proxy interface {
|
||||
AddTurnConn(ctx context.Context, turnConn net.Conn) error
|
||||
EndpointAddr() *net.UDPAddr
|
||||
Work()
|
||||
Pause()
|
||||
CloseConn() error
|
||||
}
|
||||
@@ -1,128 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
|
||||
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("trace", "console")
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
type mocConn struct {
|
||||
closeChan chan struct{}
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newMockConn() *mocConn {
|
||||
return &mocConn{
|
||||
closeChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mocConn) Read(b []byte) (n int, err error) {
|
||||
<-m.closeChan
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (m *mocConn) Write(b []byte) (n int, err error) {
|
||||
<-m.closeChan
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (m *mocConn) Close() error {
|
||||
if m.closed == true {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.closed = true
|
||||
close(m.closeChan)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mocConn) LocalAddr() net.Addr {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mocConn) RemoteAddr() net.Addr {
|
||||
return &net.UDPAddr{
|
||||
IP: net.ParseIP("172.16.254.1"),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mocConn) SetDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mocConn) SetReadDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mocConn) SetWriteDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
{
|
||||
name: "userspace proxy",
|
||||
proxy: usp.NewWGUserSpaceProxy(51830),
|
||||
},
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||
}
|
||||
}()
|
||||
proxyWrapper := &ebpf.ProxyWrapper{
|
||||
WgeBPFProxy: ebpfProxy,
|
||||
}
|
||||
|
||||
tests = append(tests, struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
name: "ebpf proxy",
|
||||
proxy: proxyWrapper,
|
||||
})
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
relayedConn := newMockConn()
|
||||
err := tt.proxy.AddTurnConn(ctx, relayedConn)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
||||
_ = relayedConn.Close()
|
||||
if err := tt.proxy.CloseConn(); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,195 +0,0 @@
|
||||
package usp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// WGUserSpaceProxy proxies
|
||||
type WGUserSpaceProxy struct {
|
||||
localWGListenPort int
|
||||
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
}
|
||||
|
||||
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
|
||||
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
|
||||
log.Debugf("Initializing new user space proxy with port %d", wgPort)
|
||||
p := &WGUserSpaceProxy{
|
||||
localWGListenPort: wgPort,
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// AddTurnConn
|
||||
// The provided Context must be non-nil. If the context expires before
|
||||
// the connection is complete, an error is returned. Once successfully
|
||||
// connected, any expiration of the context will not affect the
|
||||
// connection.
|
||||
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
|
||||
dialer := net.Dialer{}
|
||||
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||
if err != nil {
|
||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.localConn = localConn
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
|
||||
if p.localConn == nil {
|
||||
return nil
|
||||
}
|
||||
endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
|
||||
return endpointUdpAddr
|
||||
}
|
||||
|
||||
// Work starts the proxy or resumes it if it was paused
|
||||
func (p *WGUserSpaceProxy) Work() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToRemote(p.ctx)
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// Pause pauses the proxy from receiving data from the remote peer
|
||||
func (p *WGUserSpaceProxy) Pause() {
|
||||
if p.remoteConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
}
|
||||
|
||||
// CloseConn close the localConn
|
||||
func (p *WGUserSpaceProxy) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
}
|
||||
return p.close()
|
||||
}
|
||||
|
||||
func (p *WGUserSpaceProxy) close() error {
|
||||
p.closeMu.Lock()
|
||||
defer p.closeMu.Unlock()
|
||||
|
||||
// prevent double close
|
||||
if p.closed {
|
||||
return nil
|
||||
}
|
||||
p.closed = true
|
||||
|
||||
p.cancel()
|
||||
|
||||
var result *multierror.Error
|
||||
if err := p.remoteConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||
}
|
||||
|
||||
if err := p.localConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
|
||||
}
|
||||
return errors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
// proxyToRemote proxies from Wireguard to the RemoteKey
|
||||
func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
|
||||
defer func() {
|
||||
if err := p.close(); err != nil {
|
||||
log.Warnf("error in proxy to remote loop: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for ctx.Err() == nil {
|
||||
n, err := p.localConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("failed to read from wg interface conn: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = p.remoteConn.Write(buf[:n])
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("failed to write to remote conn: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyToLocal proxies from the Remote peer to local WireGuard
|
||||
// if the proxy is paused it will drain the remote conn and drop the packets
|
||||
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
|
||||
defer func() {
|
||||
if err := p.close(); err != nil {
|
||||
log.Warnf("error in proxy to local loop: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = p.localConn.Write(buf[:n])
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("failed to write to wg interface conn: %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user