Add pause function for proxies

This commit is contained in:
Zoltán Papp
2024-10-03 01:24:05 +02:00
parent 8934453b30
commit 9d75cc3273
4 changed files with 154 additions and 69 deletions

View File

@@ -5,7 +5,6 @@ package ebpf
import (
"context"
"fmt"
"io"
"net"
"os"
"sync"
@@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error {
}
// AddTurnConn add new turn connection for the proxy
func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) {
wgEndpointPort, err := p.storeTurnConn(turnConn)
if err != nil {
return nil, err
}
go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
wgEndpoint := &net.UDPAddr{
@@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error {
return nberrors.FormatErrorOrNil(result)
}
func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
defer p.removeTurnConn(endpointPort)
var (
err error
n int
)
buf := make([]byte, 1500)
for ctx.Err() == nil {
n, err = remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
return
}
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
if ctx.Err() != nil || p.ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() {

View File

@@ -5,7 +5,11 @@ package ebpf
import (
"context"
"fmt"
"io"
"net"
"sync"
log "github.com/sirupsen/logrus"
)
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
@@ -13,22 +17,62 @@ type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy
remoteConn net.Conn
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
ctx context.Context
cancel context.CancelFunc
wgEndpointPort uint16
pausedMu sync.Mutex
paused bool
isStarted bool
}
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
ctxConn, cancel := context.WithCancel(ctx)
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
cancel()
return nil, fmt.Errorf("add turn conn: %w", err)
}
e.remoteConn = remoteConn
e.cancel = cancel
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
p.wgEndpointPort = uint16(addr.Port)
return addr, err
}
func (p *ProxyWrapper) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
}
func (p *ProxyWrapper) Pause() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
}
func (p *ProxyWrapper) Resume() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
func (e *ProxyWrapper) CloseConn() error {
if e.cancel == nil {
@@ -42,3 +86,32 @@ func (e *ProxyWrapper) CloseConn() error {
}
return nil
}
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.WgeBPFProxy.removeTurnConn(p.wgEndpointPort)
var (
err error
n int
)
buf := make([]byte, 1500)
for ctx.Err() == nil {
n, err = p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointPort, err)
}
return
}
if err := p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointPort); err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}

View File

@@ -8,5 +8,7 @@ import (
// Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface {
AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
Work()
Pause()
CloseConn() error
}

View File

@@ -15,13 +15,17 @@ import (
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn
localConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
@@ -33,26 +37,53 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
return p
}
// AddTurnConn start the proxy with the given remote conn
// AddTurnConn
// The provided Context must be non-nil. If the context expires before
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
p.ctx, p.cancel = context.WithCancel(ctx)
p.remoteConn = remoteConn
var err error
dialer := net.Dialer{}
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
p.cancel()
return nil, err
}
go p.proxyToRemote()
go p.proxyToLocal()
p.ctx, p.cancel = context.WithCancel(ctx)
p.localConn = localConn
p.remoteConn = remoteConn
return p.localConn.LocalAddr(), err
}
func (p *WGUserSpaceProxy) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
if !p.isStarted {
p.isStarted = true
go p.proxyToRemote(p.ctx)
go p.proxyToLocal(p.ctx)
}
}
func (p *WGUserSpaceProxy) Pause() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
}
// CloseConn close the localConn
func (p *WGUserSpaceProxy) CloseConn() error {
if p.cancel == nil {
@@ -85,7 +116,7 @@ func (p *WGUserSpaceProxy) close() error {
}
// proxyToRemote proxies from Wireguard to the RemoteKey
func (p *WGUserSpaceProxy) proxyToRemote() {
func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to remote loop: %s", err)
@@ -93,10 +124,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
}()
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
for ctx.Err() == nil {
n, err := p.localConn.Read(buf)
if err != nil {
if p.ctx.Err() != nil {
if ctx.Err() != nil {
return
}
log.Debugf("failed to read from wg interface conn: %s", err)
@@ -105,7 +136,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
if p.ctx.Err() != nil {
if ctx.Err() != nil {
return
}
@@ -116,7 +147,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
}
// proxyToLocal proxies from the Remote peer to local WireGuard
func (p *WGUserSpaceProxy) proxyToLocal() {
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to local loop: %s", err)
@@ -124,23 +155,33 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
}()
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
n, err := p.remoteConn.Read(buf)
if err != nil {
if p.ctx.Err() != nil {
for ctx.Err() == nil {
for {
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
if p.ctx.Err() != nil {
return
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
_, err = p.localConn.Write(buf[:n])
p.pausedMu.Unlock()
if err != nil {
if ctx.Err() != nil {
return
}
log.Debugf("failed to write to wg interface conn: %s", err)
continue
}
log.Debugf("failed to write to wg interface conn: %s", err)
continue
}
}
}