diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 0d4ad2396..1b740388d 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -82,8 +82,6 @@ type Conn struct { config ConnConfig statusRecorder *Status wgProxyFactory *wgproxy.Factory - wgProxyICE wgproxy.Proxy - wgProxyRelay wgproxy.Proxy signaler *Signaler iFaceDiscover stdnet.ExternalIFaceDiscover relayManager *relayClient.Manager @@ -106,7 +104,8 @@ type Conn struct { beforeAddPeerHooks []nbnet.AddHookFunc afterRemovePeerHooks []nbnet.RemoveHookFunc - endpointRelay *net.UDPAddr + wgProxyICE wgproxy.Proxy + wgProxyRelay wgproxy.Proxy // for reconnection operations iCEDisconnected chan bool @@ -257,8 +256,7 @@ func (conn *Conn) Close() { conn.wgProxyICE = nil } - err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if err != nil { + if err := conn.removeWgPeer(); err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } @@ -430,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon conn.log.Debugf("ICE connection is ready") - conn.statusICE.Set(StatusConnected) - - defer conn.updateIceState(iceConnInfo) - if conn.currentConnPriority > priority { + conn.statusICE.Set(StatusConnected) + conn.updateIceState(iceConnInfo) return } conn.log.Infof("set ICE to active connection") - endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo) - if err != nil { - return + var ( + ep *net.UDPAddr + wgProxy wgproxy.Proxy + err error + ) + if iceConnInfo.RelayedOnLocal { + wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn) + if err != nil { + conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) + return + } + ep = wgProxy.EndpointAddr() + conn.wgProxyICE = wgProxy + } else { + directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String()) + if err != nil { + log.Errorf("failed to resolveUDPaddr") + conn.handleConfigurationFailure(err, nil) + return + } + ep = directEp } - endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) - conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP) - - conn.connIDICE = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil { - conn.log.Errorf("Before add peer hook failed: %v", err) - } + if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil { + conn.log.Errorf("Before add peer hook failed: %v", err) } conn.workerRelay.DisableWgWatcher() - err = conn.configureWGEndpoint(endpointUdpAddr) - if err != nil { - if wgProxy != nil { - if err := wgProxy.CloseConn(); err != nil { - conn.log.Warnf("Failed to close turn connection: %v", err) - } - } - conn.log.Warnf("Failed to update wg peer configuration: %v", err) + if conn.wgProxyRelay != nil { + conn.wgProxyRelay.Pause() + } + + if wgProxy != nil { + wgProxy.Work() + } + + if err = conn.configureWGEndpoint(ep); err != nil { + conn.handleConfigurationFailure(err, wgProxy) return } wgConfigWorkaround() - - if conn.wgProxyICE != nil { - if err := conn.wgProxyICE.CloseConn(); err != nil { - conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) - } - } - conn.wgProxyICE = wgProxy - conn.currentConnPriority = priority - + conn.statusICE.Set(StatusConnected) + conn.updateIceState(iceConnInfo) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) } @@ -492,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { conn.log.Tracef("ICE connection state changed to %s", newState) + if conn.wgProxyICE != nil { + if err := conn.wgProxyICE.CloseConn(); err != nil { + conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) + } + } + // switch back to relay connection - if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay { + if conn.isReadyToUpgrade() { conn.log.Debugf("ICE disconnected, set Relay to active connection") - err := conn.configureWGEndpoint(conn.endpointRelay) - if err != nil { + conn.wgProxyRelay.Work() + + if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { conn.log.Errorf("failed to switch to relay conn: %v", err) } conn.workerRelay.EnableWgWatcher(conn.ctx) @@ -506,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { changed := conn.statusICE.Get() != newState && newState != StatusConnecting conn.statusICE.Set(newState) - select { - case conn.iCEDisconnected <- changed: - default: - } + conn.notifyReconnectLoopICEDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -530,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { if conn.ctx.Err() != nil { if err := rci.relayedConn.Close(); err != nil { - log.Warnf("failed to close unnecessary relayed connection: %v", err) + conn.log.Warnf("failed to close unnecessary relayed connection: %v", err) } return } - conn.log.Debugf("Relay connection is ready to use") - conn.statusRelay.Set(StatusConnected) + conn.log.Debugf("Relay connection has been established, setup the WireGuard") - wgProxy := conn.wgProxyFactory.GetProxy() - endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn) + wgProxy, err := conn.newProxy(rci.relayedConn) if err != nil { conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) return } - conn.log.Infof("created new wgProxy for relay connection: %s", endpoint) - endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) - conn.endpointRelay = endpointUdpAddr - conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) + conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) - defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) - - if conn.currentConnPriority > connPriorityRelay { - if conn.statusICE.Get() == StatusConnected { - log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) - return - } + if conn.iceP2PIsActive() { + conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) + conn.wgProxyRelay = wgProxy + conn.statusRelay.Set(StatusConnected) + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) + return } - conn.connIDRelay = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil { - conn.log.Errorf("Before add peer hook failed: %v", err) - } + if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil { + conn.log.Errorf("Before add peer hook failed: %v", err) } - err = conn.configureWGEndpoint(endpointUdpAddr) - if err != nil { + wgProxy.Work() + if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.log.Warnf("Failed to close relay connection: %v", err) } - conn.log.Errorf("Failed to update wg peer configuration: %v", err) + conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err) return } conn.workerRelay.EnableWgWatcher(conn.ctx) + wgConfigWorkaround() - - if conn.wgProxyRelay != nil { - if err := conn.wgProxyRelay.CloseConn(); err != nil { - conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) - } - } - conn.wgProxyRelay = wgProxy conn.currentConnPriority = connPriorityRelay - + conn.statusRelay.Set(StatusConnected) + conn.wgProxyRelay = wgProxy + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.log.Infof("start to communicate with peer via relay") conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) } @@ -597,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { return } - log.Debugf("relay connection is disconnected") + conn.log.Debugf("relay connection is disconnected") if conn.currentConnPriority == connPriorityRelay { - log.Debugf("clean up WireGuard config") - err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if err != nil { + conn.log.Debugf("clean up WireGuard config") + if err := conn.removeWgPeer(); err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } } if conn.wgProxyRelay != nil { - conn.endpointRelay = nil _ = conn.wgProxyRelay.CloseConn() conn.wgProxyRelay = nil } changed := conn.statusRelay.Get() != StatusDisconnected conn.statusRelay.Set(StatusDisconnected) - - select { - case conn.relayDisconnected <- changed: - default: - } + conn.notifyReconnectLoopRelayDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -627,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { Relayed: conn.isRelayed(), ConnStatusUpdate: time.Now(), } - - err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState) - if err != nil { + if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil { conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err) } } @@ -765,6 +751,16 @@ func (conn *Conn) isConnected() bool { return true } +func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error { + conn.connIDICE = nbnet.GenerateConnID() + for _, hook := range conn.beforeAddPeerHooks { + if err := hook(conn.connIDICE, ip); err != nil { + return err + } + } + return nil +} + func (conn *Conn) freeUpConnID() { if conn.connIDRelay != "" { for _, hook := range conn.afterRemovePeerHooks { @@ -785,21 +781,52 @@ func (conn *Conn) freeUpConnID() { } } -func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) { - if !iceConnInfo.RelayedOnLocal { - return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil - } - conn.log.Debugf("setup ice turn connection") +func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { + conn.log.Debugf("setup proxied WireGuard connection") wgProxy := conn.wgProxyFactory.GetProxy() - ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn) - if err != nil { + if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil { conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) - if errClose := wgProxy.CloseConn(); errClose != nil { - conn.log.Warnf("failed to close turn proxy connection: %v", errClose) - } - return nil, nil, err + return nil, err + } + return wgProxy, nil +} + +func (conn *Conn) isReadyToUpgrade() bool { + return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay +} + +func (conn *Conn) iceP2PIsActive() bool { + return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected +} + +func (conn *Conn) removeWgPeer() error { + return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) +} + +func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) { + select { + case conn.relayDisconnected <- changed: + default: + } +} + +func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) { + select { + case conn.iCEDisconnected <- changed: + default: + } +} + +func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { + conn.log.Warnf("Failed to update wg peer configuration: %v", err) + if wgProxy != nil { + if ierr := wgProxy.CloseConn(); ierr != nil { + conn.log.Warnf("Failed to close wg proxy: %v", ierr) + } + } + if conn.wgProxyRelay != nil { + conn.wgProxyRelay.Work() } - return ep, wgProxy, nil } func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { diff --git a/client/internal/wgproxy/ebpf/proxy.go b/client/internal/wgproxy/ebpf/proxy.go index 27ede3ef1..e850f4533 100644 --- a/client/internal/wgproxy/ebpf/proxy.go +++ b/client/internal/wgproxy/ebpf/proxy.go @@ -5,7 +5,6 @@ package ebpf import ( "context" "fmt" - "io" "net" "os" "sync" @@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error { } // AddTurnConn add new turn connection for the proxy -func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) { +func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) { wgEndpointPort, err := p.storeTurnConn(turnConn) if err != nil { return nil, err } - go p.proxyToLocal(ctx, wgEndpointPort, turnConn) log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort) wgEndpoint := &net.UDPAddr{ @@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error { return nberrors.FormatErrorOrNil(result) } -func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) { - defer p.removeTurnConn(endpointPort) - - var ( - err error - n int - ) - buf := make([]byte, 1500) - for ctx.Err() == nil { - n, err = remoteConn.Read(buf) - if err != nil { - if ctx.Err() != nil { - return - } - if err != io.EOF { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) - } - return - } - - if err := p.sendPkg(buf[:n], endpointPort); err != nil { - if ctx.Err() != nil || p.ctx.Err() != nil { - return - } - log.Errorf("failed to write out turn pkg to local conn: %v", err) - } - } -} - // proxyToRemote read messages from local WireGuard interface and forward it to remote conn // From this go routine has only one instance. func (p *WGEBPFProxy) proxyToRemote() { @@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { return packetConn, nil } -func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error { +func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { localhost := net.ParseIP("127.0.0.1") payload := gopacket.Payload(data) diff --git a/client/internal/wgproxy/ebpf/wrapper.go b/client/internal/wgproxy/ebpf/wrapper.go index c5639f840..b6a8ac452 100644 --- a/client/internal/wgproxy/ebpf/wrapper.go +++ b/client/internal/wgproxy/ebpf/wrapper.go @@ -4,8 +4,13 @@ package ebpf import ( "context" + "errors" "fmt" + "io" "net" + "sync" + + log "github.com/sirupsen/logrus" ) // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call @@ -13,20 +18,55 @@ type ProxyWrapper struct { WgeBPFProxy *WGEBPFProxy remoteConn net.Conn - cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread + ctx context.Context + cancel context.CancelFunc + + wgEndpointAddr *net.UDPAddr + + pausedMu sync.Mutex + paused bool + isStarted bool } -func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { - ctxConn, cancel := context.WithCancel(ctx) - addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn) - +func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { + addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) if err != nil { - cancel() - return nil, fmt.Errorf("add turn conn: %w", err) + return fmt.Errorf("add turn conn: %w", err) } - e.remoteConn = remoteConn - e.cancel = cancel - return addr, err + p.remoteConn = remoteConn + p.ctx, p.cancel = context.WithCancel(ctx) + p.wgEndpointAddr = addr + return err +} + +func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { + return p.wgEndpointAddr +} + +func (p *ProxyWrapper) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + if !p.isStarted { + p.isStarted = true + go p.proxyToLocal(p.ctx) + } +} + +func (p *ProxyWrapper) Pause() { + if p.remoteConn == nil { + return + } + + log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() } // CloseConn close the remoteConn and automatically remove the conn instance from the map @@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error { } return nil } + +func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { + defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) + + buf := make([]byte, 1500) + for { + n, err := p.readFromRemote(ctx, buf) + if err != nil { + return + } + + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + + err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) + p.pausedMu.Unlock() + + if err != nil { + if ctx.Err() != nil { + return + } + log.Errorf("failed to write out turn pkg to local conn: %v", err) + } + } +} + +func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) { + n, err := p.remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { + return 0, ctx.Err() + } + if !errors.Is(err, io.EOF) { + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) + } + return 0, err + } + return n, nil +} diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go index 96fae8dd1..558121cdd 100644 --- a/client/internal/wgproxy/proxy.go +++ b/client/internal/wgproxy/proxy.go @@ -7,6 +7,9 @@ import ( // Proxy is a transfer layer between the relayed connection and the WireGuard type Proxy interface { - AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) + AddTurnConn(ctx context.Context, turnConn net.Conn) error + EndpointAddr() *net.UDPAddr + Work() + Pause() CloseConn() error } diff --git a/client/internal/wgproxy/proxy_test.go b/client/internal/wgproxy/proxy_test.go index b09e6be55..b88ff3f83 100644 --- a/client/internal/wgproxy/proxy_test.go +++ b/client/internal/wgproxy/proxy_test.go @@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { relayedConn := newMockConn() - _, err := tt.proxy.AddTurnConn(ctx, relayedConn) + err := tt.proxy.AddTurnConn(ctx, relayedConn) if err != nil { t.Errorf("error: %v", err) } diff --git a/client/internal/wgproxy/usp/proxy.go b/client/internal/wgproxy/usp/proxy.go index 83a8725d8..f73500717 100644 --- a/client/internal/wgproxy/usp/proxy.go +++ b/client/internal/wgproxy/usp/proxy.go @@ -15,13 +15,17 @@ import ( // WGUserSpaceProxy proxies type WGUserSpaceProxy struct { localWGListenPort int - ctx context.Context - cancel context.CancelFunc remoteConn net.Conn localConn net.Conn + ctx context.Context + cancel context.CancelFunc closeMu sync.Mutex closed bool + + pausedMu sync.Mutex + paused bool + isStarted bool } // NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation @@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { return p } -// AddTurnConn start the proxy with the given remote conn -func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { - p.ctx, p.cancel = context.WithCancel(ctx) - - p.remoteConn = remoteConn - - var err error +// AddTurnConn +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { dialer := net.Dialer{} - p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) + localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) - return nil, err + return err } - go p.proxyToRemote() - go p.proxyToLocal() + p.ctx, p.cancel = context.WithCancel(ctx) + p.localConn = localConn + p.remoteConn = remoteConn - return p.localConn.LocalAddr(), err + return err +} + +func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr { + if p.localConn == nil { + return nil + } + endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String()) + return endpointUdpAddr +} + +// Work starts the proxy or resumes it if it was paused +func (p *WGUserSpaceProxy) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + if !p.isStarted { + p.isStarted = true + go p.proxyToRemote(p.ctx) + go p.proxyToLocal(p.ctx) + } +} + +// Pause pauses the proxy from receiving data from the remote peer +func (p *WGUserSpaceProxy) Pause() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() } // CloseConn close the localConn @@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error { } // proxyToRemote proxies from Wireguard to the RemoteKey -func (p *WGUserSpaceProxy) proxyToRemote() { +func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to remote loop: %s", err) @@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() { }() buf := make([]byte, 1500) - for p.ctx.Err() == nil { + for ctx.Err() == nil { n, err := p.localConn.Read(buf) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Debugf("failed to read from wg interface conn: %s", err) @@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() { _, err = p.remoteConn.Write(buf[:n]) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } @@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() { } // proxyToLocal proxies from the Remote peer to local WireGuard -func (p *WGUserSpaceProxy) proxyToLocal() { +// if the proxy is paused it will drain the remote conn and drop the packets +func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to local loop: %s", err) @@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() { }() buf := make([]byte, 1500) - for p.ctx.Err() == nil { + for { n, err := p.remoteConn.Read(buf) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) return } + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + _, err = p.localConn.Write(buf[:n]) + p.pausedMu.Unlock() + if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Debugf("failed to write to wg interface conn: %s", err)