Monitor network changes and restart engine on detection (#1904)

This commit is contained in:
Viktor Liu
2024-05-07 18:50:34 +02:00
committed by GitHub
parent 2e0047daea
commit 920877964f
38 changed files with 1027 additions and 190 deletions

View File

@@ -3,6 +3,7 @@
package wgproxy
import (
"context"
"fmt"
"io"
"net"
@@ -22,7 +23,11 @@ import (
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
ebpfManager ebpfMgr.Manager
ebpfManager ebpfMgr.Manager
ctx context.Context
cancel context.CancelFunc
lastUsedPort uint16
localWGListenPort int
@@ -34,7 +39,7 @@ type WGEBPFProxy struct {
}
// NewWGEBPFProxy create new WGEBPFProxy instance
func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort,
@@ -42,11 +47,13 @@ func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
lastUsedPort: 0,
turnConnStore: make(map[uint16]net.Conn),
}
wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx)
return wgProxy
}
// Listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) Listen() error {
// listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) listen() error {
pl := portLookup{}
wgPorxyPort, err := pl.searchFreePort()
if err != nil {
@@ -72,7 +79,7 @@ func (p *WGEBPFProxy) Listen() error {
if err != nil {
cErr := p.Free()
if cErr != nil {
log.Errorf("failed to close the wgproxy: %s", cErr)
log.Errorf("Failed to close the wgproxy: %s", cErr)
}
return err
}
@@ -102,6 +109,7 @@ func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
// CloseConn doing nothing because this type of proxy implementation does not store the connection
func (p *WGEBPFProxy) CloseConn() error {
p.cancel()
return nil
}
@@ -131,19 +139,27 @@ func (p *WGEBPFProxy) Free() error {
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
buf := make([]byte, 1500)
var err error
defer func() {
p.removeTurnConn(endpointPort)
}()
for {
n, err := remoteConn.Read(buf)
if err != nil {
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
p.removeTurnConn(endpointPort)
log.Infof("stop forward turn packages to port: %d. error: %s", endpointPort, err)
select {
case <-p.ctx.Done():
return
}
err = p.sendPkg(buf[:n], endpointPort)
if err != nil {
log.Errorf("failed to write out turn pkg to local conn: %v", err)
default:
var n int
n, err = remoteConn.Read(buf)
if err != nil {
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
return
}
err = p.sendPkg(buf[:n], endpointPort)
if err != nil {
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
}
@@ -152,23 +168,28 @@ func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil {
log.Errorf("failed to read UDP pkg from WG: %s", err)
select {
case <-p.ctx.Done():
return
}
default:
n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil {
log.Errorf("failed to read UDP pkg from WG: %s", err)
return
}
p.turnConnMutex.Lock()
conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock()
if !ok {
log.Infof("turn conn not found by port: %d", addr.Port)
continue
}
p.turnConnMutex.Lock()
conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock()
if !ok {
log.Infof("turn conn not found by port: %d", addr.Port)
continue
}
_, err = conn.Write(buf[:n])
if err != nil {
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
_, err = conn.Write(buf[:n])
if err != nil {
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
}
}
}
}
@@ -266,15 +287,17 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error {
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return err
return fmt.Errorf("set network layer for checksum: %w", err)
}
layerBuffer := gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload)
if err != nil {
return err
return fmt.Errorf("serialize layers: %w", err)
}
_, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost})
return err
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}