diff --git a/client/internal/wgproxy/ebpf/proxy.go b/client/internal/wgproxy/ebpf/proxy.go index 27ede3ef1..4cf0f93fe 100644 --- a/client/internal/wgproxy/ebpf/proxy.go +++ b/client/internal/wgproxy/ebpf/proxy.go @@ -5,7 +5,6 @@ package ebpf import ( "context" "fmt" - "io" "net" "os" "sync" @@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error { } // AddTurnConn add new turn connection for the proxy -func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) { +func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) { wgEndpointPort, err := p.storeTurnConn(turnConn) if err != nil { return nil, err } - go p.proxyToLocal(ctx, wgEndpointPort, turnConn) log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort) wgEndpoint := &net.UDPAddr{ @@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error { return nberrors.FormatErrorOrNil(result) } -func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) { - defer p.removeTurnConn(endpointPort) - - var ( - err error - n int - ) - buf := make([]byte, 1500) - for ctx.Err() == nil { - n, err = remoteConn.Read(buf) - if err != nil { - if ctx.Err() != nil { - return - } - if err != io.EOF { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) - } - return - } - - if err := p.sendPkg(buf[:n], endpointPort); err != nil { - if ctx.Err() != nil || p.ctx.Err() != nil { - return - } - log.Errorf("failed to write out turn pkg to local conn: %v", err) - } - } -} - // proxyToRemote read messages from local WireGuard interface and forward it to remote conn // From this go routine has only one instance. func (p *WGEBPFProxy) proxyToRemote() { diff --git a/client/internal/wgproxy/ebpf/wrapper.go b/client/internal/wgproxy/ebpf/wrapper.go index c5639f840..d4e11083b 100644 --- a/client/internal/wgproxy/ebpf/wrapper.go +++ b/client/internal/wgproxy/ebpf/wrapper.go @@ -5,7 +5,11 @@ package ebpf import ( "context" "fmt" + "io" "net" + "sync" + + log "github.com/sirupsen/logrus" ) // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call @@ -13,22 +17,62 @@ type ProxyWrapper struct { WgeBPFProxy *WGEBPFProxy remoteConn net.Conn - cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread + ctx context.Context + cancel context.CancelFunc + + wgEndpointPort uint16 + + pausedMu sync.Mutex + paused bool + isStarted bool } -func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { - ctxConn, cancel := context.WithCancel(ctx) - addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn) - +func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { + addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) if err != nil { - cancel() return nil, fmt.Errorf("add turn conn: %w", err) } - e.remoteConn = remoteConn - e.cancel = cancel + p.remoteConn = remoteConn + p.ctx, p.cancel = context.WithCancel(ctx) + p.wgEndpointPort = uint16(addr.Port) return addr, err } +func (p *ProxyWrapper) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + if !p.isStarted { + p.isStarted = true + go p.proxyToLocal(p.ctx) + } +} + +func (p *ProxyWrapper) Pause() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() +} + +func (p *ProxyWrapper) Resume() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() +} + // CloseConn close the remoteConn and automatically remove the conn instance from the map func (e *ProxyWrapper) CloseConn() error { if e.cancel == nil { @@ -42,3 +86,32 @@ func (e *ProxyWrapper) CloseConn() error { } return nil } + +func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { + defer p.WgeBPFProxy.removeTurnConn(p.wgEndpointPort) + + var ( + err error + n int + ) + buf := make([]byte, 1500) + for ctx.Err() == nil { + n, err = p.remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { + return + } + if err != io.EOF { + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointPort, err) + } + return + } + + if err := p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointPort); err != nil { + if ctx.Err() != nil { + return + } + log.Errorf("failed to write out turn pkg to local conn: %v", err) + } + } +} diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go index 96fae8dd1..55093df1c 100644 --- a/client/internal/wgproxy/proxy.go +++ b/client/internal/wgproxy/proxy.go @@ -8,5 +8,7 @@ import ( // Proxy is a transfer layer between the relayed connection and the WireGuard type Proxy interface { AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) + Work() + Pause() CloseConn() error } diff --git a/client/internal/wgproxy/usp/proxy.go b/client/internal/wgproxy/usp/proxy.go index 83a8725d8..8f5044443 100644 --- a/client/internal/wgproxy/usp/proxy.go +++ b/client/internal/wgproxy/usp/proxy.go @@ -15,13 +15,17 @@ import ( // WGUserSpaceProxy proxies type WGUserSpaceProxy struct { localWGListenPort int - ctx context.Context - cancel context.CancelFunc remoteConn net.Conn localConn net.Conn + ctx context.Context + cancel context.CancelFunc closeMu sync.Mutex closed bool + + pausedMu sync.Mutex + paused bool + isStarted bool } // NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation @@ -33,26 +37,53 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { return p } -// AddTurnConn start the proxy with the given remote conn +// AddTurnConn +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { - p.ctx, p.cancel = context.WithCancel(ctx) - - p.remoteConn = remoteConn - - var err error dialer := net.Dialer{} - p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) + localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) + p.cancel() return nil, err } - go p.proxyToRemote() - go p.proxyToLocal() + p.ctx, p.cancel = context.WithCancel(ctx) + p.localConn = localConn + p.remoteConn = remoteConn return p.localConn.LocalAddr(), err } +func (p *WGUserSpaceProxy) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + if !p.isStarted { + p.isStarted = true + go p.proxyToRemote(p.ctx) + go p.proxyToLocal(p.ctx) + } +} + +func (p *WGUserSpaceProxy) Pause() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() +} + // CloseConn close the localConn func (p *WGUserSpaceProxy) CloseConn() error { if p.cancel == nil { @@ -85,7 +116,7 @@ func (p *WGUserSpaceProxy) close() error { } // proxyToRemote proxies from Wireguard to the RemoteKey -func (p *WGUserSpaceProxy) proxyToRemote() { +func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to remote loop: %s", err) @@ -93,10 +124,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() { }() buf := make([]byte, 1500) - for p.ctx.Err() == nil { + for ctx.Err() == nil { n, err := p.localConn.Read(buf) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Debugf("failed to read from wg interface conn: %s", err) @@ -105,7 +136,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() { _, err = p.remoteConn.Write(buf[:n]) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } @@ -116,7 +147,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() { } // proxyToLocal proxies from the Remote peer to local WireGuard -func (p *WGUserSpaceProxy) proxyToLocal() { +func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to local loop: %s", err) @@ -124,23 +155,33 @@ func (p *WGUserSpaceProxy) proxyToLocal() { }() buf := make([]byte, 1500) - for p.ctx.Err() == nil { - n, err := p.remoteConn.Read(buf) - if err != nil { - if p.ctx.Err() != nil { + for ctx.Err() == nil { + for { + n, err := p.remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { + return + } + log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) return } - log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) - return - } - _, err = p.localConn.Write(buf[:n]) - if err != nil { - if p.ctx.Err() != nil { - return + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + + _, err = p.localConn.Write(buf[:n]) + p.pausedMu.Unlock() + + if err != nil { + if ctx.Err() != nil { + return + } + log.Debugf("failed to write to wg interface conn: %s", err) + continue } - log.Debugf("failed to write to wg interface conn: %s", err) - continue } } }