Files
gerbil/relay/relay.go
Laurence 697f4131e7 enhancement: base context + errgroup; propagate cancellation; graceful shutdown
- main: add base context via signal.NotifyContext; establish errgroup and use it to supervise background tasks; convert ticker to context-aware periodicBandwidthCheck; run HTTP server under errgroup and add graceful shutdown; treat context.Canceled as normal exit
- relay: thread parent context through UDPProxyServer; add cancel func; make packet reader, workers, and cleanup tickers exit on ctx.Done; Stop cancels, closes listener and downstream UDP connections, and closes packet channel to drain workers
- proxy: drop earlier parent context hook for SNI proxy per review; rely on existing Stop() for graceful shutdown

Benefits:
- unified lifecycle and deterministic shutdown across components
- prevents leaked goroutines/tickers and closes sockets cleanly
- consolidated error handling via g.Wait(), with context cancellation treated as non-error
- sets foundation for child errgroups and future structured concurrency
2025-11-16 06:00:32 +00:00

1030 lines
31 KiB
Go

package relay
import (
	"bytes"
	"context"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/fosrl/gerbil/logger"
	"golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/crypto/curve25519"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// EncryptedHolePunchMessage is the JSON envelope of a hole-punch datagram.
// decryptMessage opens Ciphertext with ChaCha20-Poly1305, keyed by the X25519
// shared secret of EphemeralPublicKey and the relay's private key; the
// plaintext is a JSON-encoded HolePunchMessage.
type EncryptedHolePunchMessage struct {
	EphemeralPublicKey string `json:"ephemeralPublicKey"` // sender's ephemeral key, parsed with wgtypes.ParseKey; must be non-empty
	Nonce              []byte `json:"nonce"`              // AEAD nonce (base64 in JSON); must match the AEAD nonce size
	Ciphertext         []byte `json:"ciphertext"`         // sealed HolePunchMessage (base64 in JSON)
}

// HolePunchMessage is the decrypted hole-punch payload. Its fields are copied
// verbatim into the ClientEndpoint reported to the control server.
type HolePunchMessage struct {
	OlmID  string `json:"olmId"`
	NewtID string `json:"newtId"`
	Token  string `json:"token"`
}

// ClientEndpoint is the JSON body notifyServer POSTs to
// <serverURL>/gerbil/update-hole-punch after a valid hole-punch message.
type ClientEndpoint struct {
	OlmID       string `json:"olmId"`
	NewtID      string `json:"newtId"`
	Token       string `json:"token"`
	IP          string `json:"ip"`          // observed source IP of the hole-punch datagram
	Port        int    `json:"port"`        // observed source port of the hole-punch datagram
	Timestamp   int64  `json:"timestamp"`   // Unix seconds when the message was processed
	ReachableAt string `json:"reachableAt"` // this relay's UDPProxyServer.ReachableAt
	PublicKey   string `json:"publicKey"`   // this relay's WireGuard public key
}
// ProxyMapping lists the peer destinations that WireGuard packets from one
// client address are relayed to. It is stored in UDPProxyServer.proxyMappings
// under the client's "ip:port" key.
type ProxyMapping struct {
	Destinations []PeerDestination `json:"destinations"`
	LastUsed     time.Time         `json:"-"` // Not serialized; refreshed on use and checked by cleanupIdleProxyMappings
}

// PeerDestination is one relay target, resolved as "DestinationIP:DestinationPort".
type PeerDestination struct {
	DestinationIP   string `json:"destinationIP"`
	DestinationPort int    `json:"destinationPort"`
}

// DestinationConn is a cached connected UDP socket towards one destination on
// behalf of one client (see getOrCreateConnection). The socket is closed by
// cleanupIdleConnections, clearConnectionsForWGIP or Stop.
//
// NOTE(review): lastUsed is written by the packet workers and read by
// cleanupIdleConnections without synchronization (a data race under -race).
type DestinationConn struct {
	conn     *net.UDPConn
	lastUsed time.Time
}
// WireGuardSession routes transport data: a client transport packet whose
// receiver index equals SenderIndex is forwarded to DestAddr (see
// handleWireGuardPacket). LastSeen is refreshed on every match, and sessions
// idle for more than 15 minutes are removed by cleanupIdleSessions.
type WireGuardSession struct {
	ReceiverIndex uint32
	SenderIndex   uint32
	DestAddr      *net.UDPAddr
	LastSeen      time.Time
}

// CommunicationPattern tracks traffic between one client and one destination
// in both directions, so tryRebuildSession can recreate a session entry from
// observed transport packets.
type CommunicationPattern struct {
	FromClient     *net.UDPAddr // The client address
	ToDestination  *net.UDPAddr // The destination address
	ClientIndex    uint32       // Receiver index carried by packets from the client (0 until seen)
	DestIndex      uint32       // Receiver index carried by packets from the destination (0 until seen)
	LastFromClient time.Time    // Last packet from client to destination
	LastFromDest   time.Time    // Last packet from destination to client
	PacketCount    int          // Number of packets observed
}

// InitialMappings is the response body of <serverURL>/gerbil/get-all-relays,
// loaded by fetchInitialMappings when the server starts.
type InitialMappings struct {
	Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port"
}
// Packet is one datagram read from the main listener by readPackets. data is a
// bufferPool buffer sliced to the n bytes received (so n == len(data)), and
// remoteAddr is the sender. The packet worker returns data to bufferPool once
// the packet has been handled.
type Packet struct {
	data       []byte
	remoteAddr *net.UDPAddr
	n          int
}

// WireGuard message types: the value of the first byte of a WireGuard datagram.
const (
	WireGuardMessageTypeHandshakeInitiation = 1
	WireGuardMessageTypeHandshakeResponse   = 2
	WireGuardMessageTypeCookieReply         = 3
	WireGuardMessageTypeTransportData       = 4
)

// --- End Types ---
// bufferPool recycles the 1500-byte buffers that incoming datagrams are read
// into, to reduce allocations on the read path. Callers hand buffers back
// resliced to their full 1500-byte length.
var bufferPool = sync.Pool{
	New: func() interface{} {
		fresh := make([]byte, 1500)
		return fresh
	},
}
// UDPProxyServer relays WireGuard traffic that arrives on a single UDP
// listener to the peer destinations configured for each client address, and
// handles encrypted hole-punch messages received on the same port.
//
// One reader goroutine (readPackets) feeds datagrams to a fixed pool of
// workers through packetChan. Construct it with NewUDPProxyServer: ctx and
// cancel are only initialised there, and the background goroutines read ctx.
type UDPProxyServer struct {
	addr          string       // listen address, resolved with net.ResolveUDPAddr in Start
	serverURL     string       // control-server base URL; "/gerbil/..." paths are appended
	conn          *net.UDPConn // main listener, set by Start
	proxyMappings sync.Map     // map[string]ProxyMapping where key is the client "ip:port"
	connections   sync.Map     // map[string]*DestinationConn where key is "destAddr-remoteAddr" (see getOrCreateConnection)
	privateKey    wgtypes.Key  // decrypts hole-punch messages; its public half identifies this relay to the server
	packetChan    chan Packet  // queue from readPackets to the packet workers
	// NOTE(review): Go guidance discourages storing a Context in a struct; here
	// it scopes the lifetime of the server's background goroutines.
	ctx    context.Context
	cancel context.CancelFunc
	// Session tracking for WireGuard peers; values are *WireGuardSession.
	// Keys are "%d:%d" index pairs; the writers (handleWireGuardPacket,
	// handleResponses, tryRebuildSession) each choose their own order.
	wgSessions sync.Map
	// Communication pattern tracking for rebuilding sessions; values are
	// *CommunicationPattern. Key format: "clientIP:clientPort-destIP:destPort"
	commPatterns sync.Map
	// ReachableAt is the URL where this server can be reached; it is copied
	// into every ClientEndpoint reported to the control server.
	ReachableAt string
}
// NewUDPProxyServer builds a relay that will listen on addr once Start is
// called. The server gets its own cancelable context derived from parentCtx
// (canceled by Stop or by parentCtx itself) and a packet queue buffered for
// 1000 datagrams.
func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
	srv := &UDPProxyServer{
		addr:        addr,
		serverURL:   serverURL,
		privateKey:  privateKey,
		ReachableAt: reachableAt,
		packetChan:  make(chan Packet, 1000),
	}
	srv.ctx, srv.cancel = context.WithCancel(parentCtx)
	return srv
}
// Start loads the initial proxy mappings from the control server, binds the
// UDP listener on s.addr, and launches the background goroutines: the packet
// workers, the socket reader, and the periodic cleanup loops. Stop shuts them
// down.
//
// The initial-mapping error is wrapped with %w so callers can inspect the
// underlying cause with errors.Is / errors.As (for example a context
// cancellation while a shutdown races with startup).
func (s *UDPProxyServer) Start() error {
	if err := s.fetchInitialMappings(); err != nil {
		return fmt.Errorf("failed to fetch initial mappings: %w", err)
	}

	udpAddr, err := net.ResolveUDPAddr("udp", s.addr)
	if err != nil {
		return err
	}
	conn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		return err
	}
	s.conn = conn
	logger.Info("UDP server listening on %s", s.addr)

	// Start a fixed number of worker goroutines.
	const workerCount = 10 // TODO: Make this configurable or pick it better!
	for i := 0; i < workerCount; i++ {
		go s.packetWorker()
	}

	// The reader feeds the workers; the cleanup loops expire idle state.
	go s.readPackets()
	go s.cleanupIdleConnections()
	go s.cleanupIdleSessions()
	go s.cleanupIdleProxyMappings()
	go s.cleanupIdleCommunicationPatterns()
	return nil
}
// Stop shuts the relay down. It cancels the server context (observed by every
// background goroutine), closes the main listener so a blocked ReadFromUDP in
// readPackets returns, and closes every cached downstream UDP connection so
// the per-connection handleResponses readers return as well.
//
// Stop deliberately does not close packetChan: readPackets is the channel's
// only sender and closes it itself when it exits. Closing it here could race
// with a send readPackets is about to make ("send on closed channel" panic),
// and a second Stop would panic on the double close. Stop may be called more
// than once; repeated Close errors are ignored.
func (s *UDPProxyServer) Stop() {
	// Signal all background goroutines to stop.
	if s.cancel != nil {
		s.cancel()
	}
	// Close the listener to unblock readPackets.
	if s.conn != nil {
		_ = s.conn.Close() // best effort: an already-closed socket just returns an error
	}
	// Close all downstream UDP connections.
	s.connections.Range(func(key, value interface{}) bool {
		if dc, ok := value.(*DestinationConn); ok && dc.conn != nil {
			_ = dc.conn.Close() // best effort during shutdown
		}
		return true
	})
}

// readPackets reads datagrams from the main listener into pooled buffers and
// queues them on packetChan for the workers. It is the only sender on
// packetChan and closes it on return, which ends the workers' receive loops.
//
// It returns once the server context is canceled. Cancellation is noticed
// between reads, after a read error (Stop closes the listener to unblock a
// pending read), or while waiting for room on a full packetChan. Read errors
// outside of shutdown are logged and reading continues.
func (s *UDPProxyServer) readPackets() {
	defer close(s.packetChan)
	for {
		// Exit promptly if the context is canceled.
		if s.ctx.Err() != nil {
			return
		}
		buf := bufferPool.Get().([]byte)
		n, remoteAddr, err := s.conn.ReadFromUDP(buf)
		if err != nil {
			bufferPool.Put(buf[:1500])
			if s.ctx.Err() != nil {
				return // shutting down: Stop closed the listener
			}
			logger.Error("Error reading UDP packet: %v", err)
			continue
		}
		// Never block forever on a full queue once shutdown has begun: the
		// workers stop draining packetChan when the context is canceled.
		select {
		case s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}:
		case <-s.ctx.Done():
			bufferPool.Put(buf[:1500])
			return
		}
	}
}

// packetWorker consumes packets from packetChan until the channel is closed
// (readPackets exited) or the server context is canceled. A packet whose first
// byte is a WireGuard message type (1-4) is relayed by handleWireGuardPacket;
// anything else is treated as an encrypted hole-punch message. Each packet's
// buffer is returned to bufferPool after it has been handled.
func (s *UDPProxyServer) packetWorker() {
	for {
		select {
		case <-s.ctx.Done():
			return
		case packet, ok := <-s.packetChan:
			if !ok {
				return
			}
			// Determine packet type by inspecting the first byte.
			if packet.n > 0 && packet.data[0] >= 1 && packet.data[0] <= 4 {
				s.handleWireGuardPacket(packet.data, packet.remoteAddr)
			} else {
				s.handleHolePunchPacket(packet.data, packet.remoteAddr)
			}
			// Neither handler keeps a reference to the bytes, so recycle the buffer.
			bufferPool.Put(packet.data[:1500])
		}
	}
}

// handleHolePunchPacket decrypts a hole-punch message received from
// remoteAddr and reports the client's observed public endpoint to the control
// server via notifyServer. It then clears relay sessions whose destination is
// that IP so they can be re-established. Malformed or undecryptable messages
// are logged and dropped. data is not retained after the call returns:
// json.Unmarshal copies everything it keeps.
func (s *UDPProxyServer) handleHolePunchPacket(data []byte, remoteAddr *net.UDPAddr) {
	var encMsg EncryptedHolePunchMessage
	if err := json.Unmarshal(data, &encMsg); err != nil {
		logger.Error("Error unmarshaling encrypted message: %v", err)
		return
	}
	if encMsg.EphemeralPublicKey == "" {
		logger.Error("Received malformed message without ephemeral key")
		return
	}
	decryptedData, err := s.decryptMessage(encMsg)
	if err != nil {
		logger.Error("Failed to decrypt message: %v", err)
		return
	}
	var msg HolePunchMessage
	if err := json.Unmarshal(decryptedData, &msg); err != nil {
		logger.Error("Error unmarshaling decrypted message: %v", err)
		return
	}
	endpoint := ClientEndpoint{
		NewtID:      msg.NewtID,
		OlmID:       msg.OlmID,
		Token:       msg.Token,
		IP:          remoteAddr.IP.String(),
		Port:        remoteAddr.Port,
		Timestamp:   time.Now().Unix(),
		ReachableAt: s.ReachableAt,
		PublicKey:   s.privateKey.PublicKey().String(),
	}
	logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", remoteAddr.String(), endpoint.IP, endpoint.Port)
	s.notifyServer(endpoint)
	s.clearSessionsForIP(endpoint.IP) // Clear sessions for this IP to allow re-establishment
}
// decryptMessage opens an EncryptedHolePunchMessage addressed to this relay.
//
// The AEAD key is the raw X25519 shared secret between the relay's private key
// and the sender's ephemeral public key. The ciphertext is opened with
// ChaCha20-Poly1305 using the message's nonce and no additional data. It
// returns the plaintext, or an error naming the failing step: key parsing, key
// exchange, cipher setup, nonce length, or authentication. Underlying errors
// are wrapped with %w.
func (s *UDPProxyServer) decryptMessage(encMsg EncryptedHolePunchMessage) ([]byte, error) {
	ephPubKey, err := wgtypes.ParseKey(encMsg.EphemeralPublicKey)
	if err != nil {
		return nil, fmt.Errorf("failed to parse ephemeral public key: %w", err)
	}

	// X25519 (unlike ScalarMult) returns an error for low-order points instead
	// of silently producing an all-zero shared secret.
	sharedSecret, err := curve25519.X25519(s.privateKey[:], ephPubKey[:])
	if err != nil {
		return nil, fmt.Errorf("failed to perform X25519 key exchange: %w", err)
	}

	aead, err := chacha20poly1305.New(sharedSecret)
	if err != nil {
		return nil, fmt.Errorf("failed to create AEAD cipher: %w", err)
	}

	// aead.Open panics on a wrong nonce length, so reject it up front.
	if len(encMsg.Nonce) != aead.NonceSize() {
		return nil, fmt.Errorf("invalid nonce size")
	}

	plaintext, err := aead.Open(nil, encMsg.Nonce, encMsg.Ciphertext, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt message: %w", err)
	}
	return plaintext, nil
}
// fetchInitialMappings loads every relay mapping the control server holds for
// this relay into s.proxyMappings, stamping each one with the current time as
// LastUsed. It does this with POST <serverURL>/gerbil/get-all-relays carrying
// this relay's public key.
//
// It returns an error if the request fails, the server answers with a non-200
// status, or the body is not valid InitialMappings JSON. The request is bound
// to the server context, so a shutdown that happens while Start is still
// bootstrapping aborts the call instead of waiting for it.
func (s *UDPProxyServer) fetchInitialMappings() error {
	body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, s.privateKey.PublicKey().String())))

	ctx := s.ctx
	if ctx == nil {
		ctx = context.Background() // no server context: issue an uncancellable request
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.serverURL+"/gerbil/get-all-relays", body)
	if err != nil {
		return fmt.Errorf("failed to fetch mappings: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to fetch mappings: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		errBody, _ := io.ReadAll(resp.Body) // best effort: only used in the error text
		return fmt.Errorf("server returned non-OK status: %d, body: %s",
			resp.StatusCode, string(errBody))
	}

	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to read response body: %w", err)
	}
	logger.Info("Received initial mappings: %s", string(data))

	var initialMappings InitialMappings
	if err := json.Unmarshal(data, &initialMappings); err != nil {
		return fmt.Errorf("failed to unmarshal initial mappings: %w", err)
	}

	for key, mapping := range initialMappings.Mappings {
		// Initialize LastUsed so the idle sweep does not evict them immediately.
		mapping.LastUsed = time.Now()
		s.proxyMappings.Store(key, mapping)
	}
	logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings))
	return nil
}
// extractWireGuardIndices reads the session indices from a WireGuard message
// header and returns (receiverIndex, senderIndex, ok). An index that the
// message type does not carry is returned as 0. ok is false when the packet is
// shorter than 12 bytes or its first byte is not a known WireGuard type.
//
// Layouts (little-endian; byte 0 is the type, bytes 1-3 are reserved):
//   - handshake initiation: sender index at [4:8]
//   - handshake response:   sender index at [4:8], receiver index at [8:12]
//   - cookie reply:         receiver index at [4:8]
//   - transport data:       receiver index at [4:8]
//
// Cookie replies are accepted so that callers such as handleWireGuardPacket
// can relay them. Rejecting them made the relay drop every cookie reply as
// "Failed to extract WireGuard indices".
func extractWireGuardIndices(packet []byte) (uint32, uint32, bool) {
	if len(packet) < 12 {
		return 0, 0, false
	}
	switch packet[0] {
	case WireGuardMessageTypeHandshakeInitiation:
		return 0, binary.LittleEndian.Uint32(packet[4:8]), true
	case WireGuardMessageTypeHandshakeResponse:
		senderIndex := binary.LittleEndian.Uint32(packet[4:8])
		receiverIndex := binary.LittleEndian.Uint32(packet[8:12])
		return receiverIndex, senderIndex, true
	case WireGuardMessageTypeCookieReply, WireGuardMessageTypeTransportData:
		return binary.LittleEndian.Uint32(packet[4:8]), 0, true
	default:
		return 0, 0, false
	}
}
// handleWireGuardPacket relays a WireGuard datagram that arrived on the main
// listener from remoteAddr to the destination(s) in that client's proxy
// mapping.
//
// The packet is dropped, with an error log, in three cases:
//   - it is empty;
//   - extractWireGuardIndices rejects it;
//   - no proxy mapping exists for remoteAddr.String().
//
// Otherwise the mapping's LastUsed is refreshed and the packet is routed by
// message type:
//   - handshake initiation: sent to every mapped destination;
//   - handshake response: a session keyed "receiverIndex:senderIndex" with
//     DestAddr = remoteAddr is recorded, then the packet is sent to every
//     mapped destination;
//   - transport data: sent only to the DestAddr of the first session whose
//     SenderIndex equals the packet's receiver index. If there is no such
//     session it is sent to every mapped destination, recording a
//     communication pattern for each so a session can later be rebuilt;
//   - any other type: sent to every mapped destination.
//
// A failure for one destination is logged and does not stop delivery to the
// others.
func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
	if len(packet) == 0 {
		logger.Error("Received empty packet")
		return
	}
	msgType := packet[0]
	rxIndex, txIndex, valid := extractWireGuardIndices(packet)
	if !valid {
		logger.Error("Failed to extract WireGuard indices")
		return
	}

	clientKey := remoteAddr.String()
	entry, found := s.proxyMappings.Load(clientKey)
	if !found {
		logger.Error("No proxy mapping found for %s", clientKey)
		return
	}
	mapping := entry.(ProxyMapping)
	// Refresh the idle timestamp so cleanupIdleProxyMappings keeps this entry.
	mapping.LastUsed = time.Now()
	s.proxyMappings.Store(clientKey, mapping)

	// sendToAll delivers the packet to every mapped destination. With track
	// set, the client->destination pattern is recorded just before each write.
	// Write errors go to onWriteErr so each message type keeps its own log line.
	sendToAll := func(track bool, onWriteErr func(error)) {
		for _, peer := range mapping.Destinations {
			target, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", peer.DestinationIP, peer.DestinationPort))
			if err != nil {
				logger.Error("Failed to resolve destination address: %v", err)
				continue
			}
			out, err := s.getOrCreateConnection(target, remoteAddr)
			if err != nil {
				logger.Error("Failed to get/create connection: %v", err)
				continue
			}
			if track {
				s.trackCommunicationPattern(remoteAddr, target, rxIndex, true)
			}
			if _, err := out.Write(packet); err != nil {
				onWriteErr(err)
			}
		}
	}

	switch msgType {
	case WireGuardMessageTypeHandshakeInitiation:
		logger.Debug("Forwarding handshake initiation from %s (sender index: %d) to peers %v", remoteAddr, txIndex, mapping.Destinations)
		sendToAll(false, func(err error) {
			logger.Error("Failed to forward handshake initiation: %v", err)
		})

	case WireGuardMessageTypeHandshakeResponse:
		logger.Debug("Received handshake response with receiver index %d and sender index %d from %s",
			rxIndex, txIndex, remoteAddr)
		s.wgSessions.Store(fmt.Sprintf("%d:%d", rxIndex, txIndex), &WireGuardSession{
			ReceiverIndex: rxIndex,
			SenderIndex:   txIndex,
			DestAddr:      remoteAddr,
			LastSeen:      time.Now(),
		})
		sendToAll(false, func(err error) {
			logger.Error("Failed to forward handshake response: %v", err)
		})

	case WireGuardMessageTypeTransportData:
		// Find the session (if any) that owns this receiver index.
		var sessionDest *net.UDPAddr
		s.wgSessions.Range(func(k, v interface{}) bool {
			sess := v.(*WireGuardSession)
			if sess.SenderIndex != rxIndex {
				return true // keep looking
			}
			sessionDest = sess.DestAddr
			sess.LastSeen = time.Now()
			s.wgSessions.Store(k, sess)
			return false // first match wins
		})

		if sessionDest == nil {
			logger.Debug("No session found for receiver index %d, forwarding to all destinations", rxIndex)
			sendToAll(true, func(err error) {
				logger.Debug("Failed to forward transport data: %v", err)
			})
			return
		}

		out, err := s.getOrCreateConnection(sessionDest, remoteAddr)
		if err != nil {
			logger.Error("Failed to get/create connection: %v", err)
			return
		}
		s.trackCommunicationPattern(remoteAddr, sessionDest, rxIndex, true)
		if _, err := out.Write(packet); err != nil {
			logger.Debug("Failed to forward transport data: %v", err)
		}

	default:
		logger.Debug("Forwarding WireGuard packet type %d from %s", msgType, remoteAddr)
		sendToAll(false, func(err error) {
			logger.Error("Failed to forward WireGuard packet: %v", err)
		})
	}
}
// getOrCreateConnection returns the cached connected UDP socket used to relay
// packets from the client at remoteAddr to destAddr. On first use it dials a
// new socket and starts its handleResponses reader. Entries are keyed
// "destAddr-remoteAddr", and each use refreshes lastUsed for the idle sweep.
//
// Several packet workers can miss the cache for the same key at the same time.
// The new socket is published with LoadOrStore so exactly one wins; a losing
// socket is closed instead of overwriting the winner. Overwriting would leak
// the replaced socket and its reader goroutine, because neither cleanup nor
// Stop could reach it any more.
//
// NOTE(review): lastUsed is updated without synchronization while
// cleanupIdleConnections reads it (data race under -race).
func (s *UDPProxyServer) getOrCreateConnection(destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) (*net.UDPConn, error) {
	key := destAddr.String() + "-" + remoteAddr.String()

	if cached, ok := s.connections.Load(key); ok {
		destConn := cached.(*DestinationConn)
		destConn.lastUsed = time.Now()
		return destConn.conn, nil
	}

	newConn, err := net.DialUDP("udp", nil, destAddr)
	if err != nil {
		return nil, fmt.Errorf("failed to create UDP connection: %w", err)
	}

	actual, loaded := s.connections.LoadOrStore(key, &DestinationConn{
		conn:     newConn,
		lastUsed: time.Now(),
	})
	if loaded {
		// Another worker published a socket first; use it and drop ours.
		_ = newConn.Close()
		destConn := actual.(*DestinationConn)
		destConn.lastUsed = time.Now()
		return destConn.conn, nil
	}

	// Start a goroutine to handle responses
	go s.handleResponses(newConn, destAddr, remoteAddr)
	return newConn, nil
}
// handleResponses is the reader goroutine for one downstream connection
// created by getOrCreateConnection. It relays every datagram received from
// destAddr back to the client at remoteAddr through the main listener. Along
// the way it:
//   - records a WireGuard session for each handshake response (key
//     "senderIndex:receiverIndex", DestAddr = destAddr);
//   - tracks destination->client transport traffic for session rebuilding.
//
// On a read error the goroutine exits. Before returning it removes its own
// socket from s.connections (only if that entry still holds this exact
// socket) and closes it. An error that is not caused by a close, such as
// ECONNREFUSED after an ICMP port-unreachable on a connected UDP socket, would
// otherwise leave a cached socket with no reader. Every later response for
// this client/destination pair would be dropped, and the entry would never go
// idle because outbound traffic keeps refreshing lastUsed. After eviction, the
// next outbound packet dials a fresh socket with a new reader.
func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) {
	buffer := make([]byte, 1500)
	for {
		n, err := conn.Read(buffer)
		if err != nil {
			logger.Debug("Error reading response from %s: %v", destAddr.String(), err)
			// Same key format as getOrCreateConnection.
			key := destAddr.String() + "-" + remoteAddr.String()
			// NOTE: Load+Delete is not atomic. In a narrow race, a replacement
			// stored between the two calls could be dropped from the cache. The
			// pointer check covers the common cases (entry already removed or
			// already replaced).
			if cached, ok := s.connections.Load(key); ok {
				if dc, ok := cached.(*DestinationConn); ok && dc.conn == conn {
					s.connections.Delete(key)
				}
			}
			_ = conn.Close() // may already be closed by Stop or a cleanup path
			return
		}

		// Process the response to track sessions if it's a WireGuard packet
		if n > 0 && buffer[0] >= 1 && buffer[0] <= 4 {
			receiverIndex, senderIndex, ok := extractWireGuardIndices(buffer[:n])
			if ok && buffer[0] == WireGuardMessageTypeHandshakeResponse {
				// Store the session mapping for the handshake response
				sessionKey := fmt.Sprintf("%d:%d", senderIndex, receiverIndex)
				s.wgSessions.Store(sessionKey, &WireGuardSession{
					ReceiverIndex: receiverIndex,
					SenderIndex:   senderIndex,
					DestAddr:      destAddr,
					LastSeen:      time.Now(),
				})
				logger.Debug("Stored session mapping: %s -> %s", sessionKey, destAddr.String())
			} else if ok && buffer[0] == WireGuardMessageTypeTransportData {
				// Track communication pattern for session rebuilding (reverse direction)
				s.trackCommunicationPattern(destAddr, remoteAddr, receiverIndex, false)
			}
		}

		// Forward the response back through the main listener
		if _, err := s.conn.WriteToUDP(buffer[:n], remoteAddr); err != nil {
			logger.Error("Failed to forward response: %v", err)
		}
	}
}
// cleanupIdleConnections runs until the server context is canceled. Every 5
// minutes it closes and forgets downstream connections whose lastUsed is more
// than 10 minutes old. Closing a socket also makes its handleResponses reader
// return.
func (s *UDPProxyServer) cleanupIdleConnections() {
	const (
		sweepInterval = 5 * time.Minute
		maxIdle       = 10 * time.Minute
	)
	ticker := time.NewTicker(sweepInterval)
	defer ticker.Stop()

	for {
		select {
		case <-s.ctx.Done():
			return
		case <-ticker.C:
		}

		now := time.Now()
		s.connections.Range(func(key, value interface{}) bool {
			dc := value.(*DestinationConn)
			if now.Sub(dc.lastUsed) <= maxIdle {
				return true
			}
			dc.conn.Close()
			s.connections.Delete(key)
			return true
		})
	}
}
// cleanupIdleSessions runs until the server context is canceled. Every 5
// minutes it deletes WireGuard sessions whose LastSeen is more than 15 minutes
// old.
func (s *UDPProxyServer) cleanupIdleSessions() {
	const (
		sweepInterval = 5 * time.Minute
		maxIdle       = 15 * time.Minute
	)
	ticker := time.NewTicker(sweepInterval)
	defer ticker.Stop()

	for {
		select {
		case <-s.ctx.Done():
			return
		case <-ticker.C:
		}

		now := time.Now()
		s.wgSessions.Range(func(key, value interface{}) bool {
			if now.Sub(value.(*WireGuardSession).LastSeen) > maxIdle {
				s.wgSessions.Delete(key)
				logger.Debug("Removed idle session: %s", key)
			}
			return true
		})
	}
}
// cleanupIdleProxyMappings runs until the server context is canceled. Every 10
// minutes it deletes proxy mappings whose LastUsed is more than 30 minutes
// old. This includes mappings loaded at startup by fetchInitialMappings.
func (s *UDPProxyServer) cleanupIdleProxyMappings() {
	const (
		sweepInterval = 10 * time.Minute
		maxIdle       = 30 * time.Minute
	)
	ticker := time.NewTicker(sweepInterval)
	defer ticker.Stop()

	for {
		select {
		case <-s.ctx.Done():
			return
		case <-ticker.C:
		}

		now := time.Now()
		s.proxyMappings.Range(func(key, value interface{}) bool {
			if now.Sub(value.(ProxyMapping).LastUsed) > maxIdle {
				s.proxyMappings.Delete(key)
				logger.Debug("Removed idle proxy mapping: %s", key)
			}
			return true
		})
	}
}
// notifyServer reports a client's observed public endpoint to the control
// server (POST <serverURL>/gerbil/update-hole-punch). It then stores the
// ProxyMapping returned in the response under the client's address key, with
// LastUsed set to now. Failures are logged and otherwise ignored.
//
// The request is bound to the server context, so a worker blocked on this
// call is released when the relay shuts down.
func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) {
	logger.Debug("notifyServer called with endpoint: IP=%s, Port=%d", endpoint.IP, endpoint.Port)
	jsonData, err := json.Marshal(endpoint)
	if err != nil {
		logger.Error("Failed to marshal endpoint data: %v", err)
		return
	}

	ctx := s.ctx
	if ctx == nil {
		ctx = context.Background() // no server context: issue an uncancellable request
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.serverURL+"/gerbil/update-hole-punch", bytes.NewBuffer(jsonData))
	if err != nil {
		logger.Error("Failed to notify server: %v", err)
		return
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		logger.Error("Failed to notify server: %v", err)
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body) // best effort: only used in the log line
		logger.Error("Server returned non-OK status: %d, body: %s",
			resp.StatusCode, string(body))
		return
	}

	var mapping ProxyMapping
	if err := json.NewDecoder(resp.Body).Decode(&mapping); err != nil {
		logger.Error("Failed to decode proxy mapping: %v", err)
		return
	}
	logger.Debug("Received proxy mapping from server: %v", mapping)

	// Key the mapping exactly as handleWireGuardPacket looks it up
	// (remoteAddr.String()). net.JoinHostPort brackets IPv6 literals
	// ("[fd00::1]:51820"). Plain "%s:%d" formatting does not, so it would give
	// IPv6 clients a key that no lookup matches. IPv4 keys are the same either way.
	key := net.JoinHostPort(endpoint.IP, fmt.Sprint(endpoint.Port))
	logger.Debug("About to store proxy mapping with key: %s (from endpoint IP=%s, Port=%d)", key, endpoint.IP, endpoint.Port)
	mapping.LastUsed = time.Now()
	s.proxyMappings.Store(key, mapping)
	logger.Debug("Stored proxy mapping for %s with %d destinations (timestamp: %v)", key, len(mapping.Destinations), mapping.LastUsed)
}
// Updated to support multiple destinations
func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, destinations []PeerDestination) {
key := fmt.Sprintf("%s:%d", sourceIP, sourcePort)
mapping := ProxyMapping{
Destinations: destinations,
LastUsed: time.Now(),
}
s.proxyMappings.Store(key, mapping)
}
// OnPeerAdded handles a WireGuard peer with tunnel address wgIP being added.
// It drops every cached downstream connection whose key contains that IP (see
// clearConnectionsForWGIP), so later traffic dials fresh sockets.
//
// Sessions and proxy mappings are deliberately left alone. Session DestAddr
// values are not WireGuard tunnel IPs, so clearing sessions by wgIP would not
// match them.
func (s *UDPProxyServer) OnPeerAdded(wgIP string) {
	logger.Info("Clearing connections for added peer with WG IP: %s", wgIP)
	s.clearConnectionsForWGIP(wgIP)
}
// OnPeerRemoved handles a WireGuard peer with tunnel address wgIP being
// removed. It closes and forgets every cached downstream connection whose key
// contains that IP (see clearConnectionsForWGIP).
//
// Sessions and proxy mappings are deliberately left alone. Session DestAddr
// values are not WireGuard tunnel IPs, so clearing sessions by wgIP would not
// match them.
func (s *UDPProxyServer) OnPeerRemoved(wgIP string) {
	logger.Info("Clearing connections for removed peer with WG IP: %s", wgIP)
	s.clearConnectionsForWGIP(wgIP)
}
// clearConnectionsForWGIP closes and removes every cached downstream
// connection whose "destAddr-remoteAddr" key has wgIP as the host of either
// address (as decided by containsIP). Sockets are closed during the scan, and
// the map entries are deleted after it finishes.
func (s *UDPProxyServer) clearConnectionsForWGIP(wgIP string) {
	var doomed []string
	s.connections.Range(func(k, v interface{}) bool {
		connKey := k.(string)
		dc := v.(*DestinationConn)
		if !containsIP(connKey, wgIP) {
			return true
		}
		doomed = append(doomed, connKey)
		dc.conn.Close()
		logger.Debug("Closing connection for WG IP %s: %s", wgIP, connKey)
		return true
	})

	for _, connKey := range doomed {
		s.connections.Delete(connKey)
	}
	logger.Info("Cleared %d connections for WG IP: %s", len(doomed), wgIP)
}
// clearSessionsForIP deletes every WireGuard session whose DestAddr host,
// compared in net.IP.String form, equals ip. Sessions with a nil DestAddr are
// kept. The hole-punch path calls it so sessions for that IP can be
// re-established.
func (s *UDPProxyServer) clearSessionsForIP(ip string) {
	var stale []string
	s.wgSessions.Range(func(k, v interface{}) bool {
		sessionKey := k.(string)
		sess := v.(*WireGuardSession)
		if sess.DestAddr == nil || sess.DestAddr.IP.String() != ip {
			return true
		}
		stale = append(stale, sessionKey)
		logger.Debug("Marking session for deletion for WG IP %s: %s", ip, sessionKey)
		return true
	})

	for _, sessionKey := range stale {
		s.wgSessions.Delete(sessionKey)
	}
	logger.Info("Cleared %d sessions for WG IP: %s", len(stale), ip)
}
// // clearProxyMappingsForWGIP removes all proxy mappings that have destinations pointing to a specific WireGuard IP
// func (s *UDPProxyServer) clearProxyMappingsForWGIP(wgIP string) {
// var keysToDelete []string
// s.proxyMappings.Range(func(key, value interface{}) bool {
// keyStr := key.(string)
// mapping := value.(ProxyMapping)
// // Check if any destination in the mapping contains the WG IP
// for _, dest := range mapping.Destinations {
// if dest.DestinationIP == wgIP {
// keysToDelete = append(keysToDelete, keyStr)
// logger.Debug("Marking proxy mapping for deletion for WG IP %s: %s -> %s:%d", wgIP, keyStr, dest.DestinationIP, dest.DestinationPort)
// break // Found one destination, no need to check others in this mapping
// }
// }
// return true
// })
// // Delete the proxy mappings
// for _, key := range keysToDelete {
// s.proxyMappings.Delete(key)
// logger.Debug("Deleted proxy mapping: %s", key)
// }
// logger.Info("Cleared %d proxy mappings for WG IP: %s", len(keysToDelete), wgIP)
// }
// containsIP reports whether ip is the host part of either address in a
// connection key of the form "destAddr-remoteAddr". Each address is formatted
// like net.UDPAddr.String, e.g. "1.2.3.4:51820" or "[fd00::1]:51820".
//
// The host is matched together with the ':' before the port, so "10.0.0.1"
// does not match "10.0.0.10:...". IPv6 literals are bracketed before matching,
// as net.UDPAddr.String does; without that, an IPv6 ip could never match. An
// ip that is already bracketed is used as-is.
func containsIP(connectionKey, ip string) bool {
	prefix := ip + ":"
	if len(ip) > 0 && ip[0] != '[' {
		// JoinHostPort brackets hosts containing ':' and leaves IPv4 untouched.
		prefix = net.JoinHostPort(ip, "")
	}
	hasPrefix := func(s string) bool {
		return len(s) >= len(prefix) && s[:len(prefix)] == prefix
	}

	// Destination address: the start of the key.
	if hasPrefix(connectionKey) {
		return true
	}
	// Remote address: everything after the first '-'.
	for i := 0; i < len(connectionKey); i++ {
		if connectionKey[i] == '-' {
			return hasPrefix(connectionKey[i+1:])
		}
	}
	return false
}
// UpdateDestinationInMappings replaces every occurrence of oldDest (matched on
// IP and port) with newDest across all proxy mappings. It refreshes LastUsed
// on each mapping it changes and returns the number of mappings updated.
//
// Stored mappings are treated as immutable: the destinations slice is copied
// before it is modified, and the copy is stored back. Editing the stored slice
// in place would mutate a backing array that packet workers may be iterating
// concurrently (handleWireGuardPacket ranges over a copy of the mapping that
// shares that array). That array may also belong to the caller of
// UpdateProxyMapping.
func (s *UDPProxyServer) UpdateDestinationInMappings(oldDest, newDest PeerDestination) int {
	updatedCount := 0
	s.proxyMappings.Range(func(key, value interface{}) bool {
		keyStr := key.(string)
		mapping := value.(ProxyMapping)

		var rewritten []PeerDestination // private copy, allocated on the first match
		for i, dest := range mapping.Destinations {
			if dest.DestinationIP != oldDest.DestinationIP || dest.DestinationPort != oldDest.DestinationPort {
				continue
			}
			if rewritten == nil {
				rewritten = append([]PeerDestination(nil), mapping.Destinations...)
			}
			rewritten[i] = newDest
			logger.Debug("Updated destination in mapping %s: %s:%d -> %s:%d",
				keyStr, oldDest.DestinationIP, oldDest.DestinationPort,
				newDest.DestinationIP, newDest.DestinationPort)
		}

		// Store a fresh mapping value only if something changed.
		if rewritten != nil {
			mapping.Destinations = rewritten
			mapping.LastUsed = time.Now()
			s.proxyMappings.Store(keyStr, mapping)
			updatedCount++
		}
		return true // continue iteration
	})

	if updatedCount > 0 {
		logger.Info("Updated %d proxy mappings from %s:%d to %s:%d",
			updatedCount, oldDest.DestinationIP, oldDest.DestinationPort,
			newDest.DestinationIP, newDest.DestinationPort)
	}
	return updatedCount
}
// trackCommunicationPattern records one relayed WireGuard transport packet in
// the client/destination communication pattern used for session rebuilding.
// fromClient gives the direction of travel:
//   - true: fromAddr is the client and toAddr the destination;
//   - false: the roles are swapped.
//
// receiverIndex is recorded as the pattern's ClientIndex (packets from the
// client) or DestIndex (packets from the destination), and only while that
// side's index is still zero.
//
// Patterns are keyed "clientAddr-destAddr". A new pattern is only stored; an
// existing one is updated, stored back, and passed to tryRebuildSession.
//
// NOTE(review): pattern fields are mutated without synchronization from packet
// workers and response readers. A Load-then-Store for a new key can also lose
// a concurrently created pattern.
func (s *UDPProxyServer) trackCommunicationPattern(fromAddr, toAddr *net.UDPAddr, receiverIndex uint32, fromClient bool) {
	// Normalise the direction so a pattern always reads client -> destination.
	clientAddr, destAddr := toAddr, fromAddr
	var clientIndex, destIndex uint32 = 0, receiverIndex
	if fromClient {
		clientAddr, destAddr = fromAddr, toAddr
		clientIndex, destIndex = receiverIndex, 0
	}

	patternKey := fmt.Sprintf("%s-%s", clientAddr.String(), destAddr.String())
	now := time.Now()

	existing, known := s.commPatterns.Load(patternKey)
	if !known {
		fresh := &CommunicationPattern{
			FromClient:    clientAddr,
			ToDestination: destAddr,
			ClientIndex:   clientIndex,
			DestIndex:     destIndex,
			PacketCount:   1,
		}
		if fromClient {
			fresh.LastFromClient = now
		} else {
			fresh.LastFromDest = now
		}
		s.commPatterns.Store(patternKey, fresh)
		return
	}

	p := existing.(*CommunicationPattern)
	if fromClient {
		p.LastFromClient = now
		if p.ClientIndex == 0 {
			p.ClientIndex = clientIndex
		}
	} else {
		p.LastFromDest = now
		if p.DestIndex == 0 {
			p.DestIndex = destIndex
		}
	}
	p.PacketCount++
	s.commPatterns.Store(patternKey, p)

	// With traffic seen in both directions this may recreate a lost session.
	s.tryRebuildSession(p)
}
// tryRebuildSession recreates a WireGuard session entry from an observed
// communication pattern. It acts only when all of these hold:
//   - the latest packets in each direction are less than 30 seconds apart;
//   - both receiver indices are known (non-zero);
//   - at least 4 packets have been counted.
//
// The rebuilt session routes client transport packets whose receiver index is
// pattern.ClientIndex to pattern.ToDestination. Nothing is stored if an entry
// with the same key already exists.
//
// NOTE(review): this key is "DestIndex:ClientIndex", which puts the client's
// own index first. That is the reverse of the "destination index:client index"
// order that handleResponses and handleWireGuardPacket produce, so the
// existence check does not see handshake-created entries. A second entry with
// the same routing may then be stored; routing is unaffected because lookups
// match on SenderIndex. Switching the order would let the client-pointing
// entry written by handleWireGuardPacket block this rebuild, so check that
// before "fixing" it.
func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
	gap := pattern.LastFromClient.Sub(pattern.LastFromDest)
	if gap < 0 {
		gap = -gap
	}
	recentBothWays := gap < 30*time.Second
	indicesKnown := pattern.ClientIndex != 0 && pattern.DestIndex != 0
	if !recentBothWays || !indicesKnown || pattern.PacketCount < 4 {
		return
	}

	sessionKey := fmt.Sprintf("%d:%d", pattern.DestIndex, pattern.ClientIndex)
	if _, exists := s.wgSessions.Load(sessionKey); exists {
		return
	}
	s.wgSessions.Store(sessionKey, &WireGuardSession{
		ReceiverIndex: pattern.DestIndex,
		SenderIndex:   pattern.ClientIndex,
		DestAddr:      pattern.ToDestination,
		LastSeen:      time.Now(),
	})
	logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)",
		sessionKey, pattern.ToDestination.String(), pattern.PacketCount)
}
// cleanupIdleCommunicationPatterns runs until the server context is canceled.
// Every 10 minutes it deletes communication patterns whose most recent packet,
// in either direction, is more than 20 minutes old.
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
	const (
		sweepInterval = 10 * time.Minute
		maxIdle       = 20 * time.Minute
	)
	ticker := time.NewTicker(sweepInterval)
	defer ticker.Stop()

	for {
		select {
		case <-s.ctx.Done():
			return
		case <-ticker.C:
		}

		now := time.Now()
		s.commPatterns.Range(func(key, value interface{}) bool {
			p := value.(*CommunicationPattern)
			lastActivity := p.LastFromClient
			if p.LastFromDest.After(lastActivity) {
				lastActivity = p.LastFromDest
			}
			if now.Sub(lastActivity) > maxIdle {
				s.commPatterns.Delete(key)
				logger.Debug("Removed idle communication pattern: %s", key)
			}
			return true
		})
	}
}