mirror of
https://github.com/fosrl/gerbil.git
synced 2026-02-08 05:56:40 +00:00
966 lines
30 KiB
Go
966 lines
30 KiB
Go
package relay
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/fosrl/gerbil/logger"
|
|
"golang.org/x/crypto/chacha20poly1305"
|
|
"golang.org/x/crypto/curve25519"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
)
|
|
|
|
type EncryptedHolePunchMessage struct {
|
|
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
|
Nonce []byte `json:"nonce"`
|
|
Ciphertext []byte `json:"ciphertext"`
|
|
}
|
|
|
|
type HolePunchMessage struct {
|
|
OlmID string `json:"olmId"`
|
|
NewtID string `json:"newtId"`
|
|
Token string `json:"token"`
|
|
}
|
|
|
|
type ClientEndpoint struct {
|
|
OlmID string `json:"olmId"`
|
|
NewtID string `json:"newtId"`
|
|
Token string `json:"token"`
|
|
IP string `json:"ip"`
|
|
Port int `json:"port"`
|
|
Timestamp int64 `json:"timestamp"`
|
|
ReachableAt string `json:"reachableAt"`
|
|
PublicKey string `json:"publicKey"`
|
|
}
|
|
|
|
// Updated to support multiple destination peers
|
|
type ProxyMapping struct {
|
|
Destinations []PeerDestination `json:"destinations"`
|
|
LastUsed time.Time `json:"-"` // Not serialized, used for cleanup
|
|
}
|
|
|
|
type PeerDestination struct {
|
|
DestinationIP string `json:"destinationIP"`
|
|
DestinationPort int `json:"destinationPort"`
|
|
}
|
|
|
|
type DestinationConn struct {
|
|
conn *net.UDPConn
|
|
lastUsed time.Time
|
|
}
|
|
|
|
// Type for storing WireGuard handshake information
|
|
type WireGuardSession struct {
|
|
ReceiverIndex uint32
|
|
SenderIndex uint32
|
|
DestAddr *net.UDPAddr
|
|
LastSeen time.Time
|
|
}
|
|
|
|
// Type for tracking bidirectional communication patterns to rebuild sessions
|
|
type CommunicationPattern struct {
|
|
FromClient *net.UDPAddr // The client address
|
|
ToDestination *net.UDPAddr // The destination address
|
|
ClientIndex uint32 // The receiver index seen from client
|
|
DestIndex uint32 // The receiver index seen from destination
|
|
LastFromClient time.Time // Last packet from client to destination
|
|
LastFromDest time.Time // Last packet from destination to client
|
|
PacketCount int // Number of packets observed
|
|
}
|
|
|
|
type InitialMappings struct {
|
|
Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port"
|
|
}
|
|
|
|
// Packet is a simple struct to hold the packet data and sender info.
|
|
type Packet struct {
|
|
data []byte
|
|
remoteAddr *net.UDPAddr
|
|
n int
|
|
}
|
|
|
|
// WireGuard message types
|
|
const (
|
|
WireGuardMessageTypeHandshakeInitiation = 1
|
|
WireGuardMessageTypeHandshakeResponse = 2
|
|
WireGuardMessageTypeCookieReply = 3
|
|
WireGuardMessageTypeTransportData = 4
|
|
)
|
|
|
|
// --- End Types ---
|
|
|
|
// bufferPool allows reusing buffers to reduce allocations.
|
|
var bufferPool = sync.Pool{
|
|
New: func() interface{} {
|
|
return make([]byte, 1500)
|
|
},
|
|
}
|
|
|
|
// UDPProxyServer has a channel for incoming packets.
|
|
type UDPProxyServer struct {
|
|
addr string
|
|
serverURL string
|
|
conn *net.UDPConn
|
|
proxyMappings sync.Map // map[string]ProxyMapping where key is "ip:port"
|
|
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
|
|
privateKey wgtypes.Key
|
|
packetChan chan Packet
|
|
|
|
// Session tracking for WireGuard peers
|
|
// Key format: "senderIndex:receiverIndex"
|
|
wgSessions sync.Map
|
|
// Communication pattern tracking for rebuilding sessions
|
|
// Key format: "clientIP:clientPort-destIP:destPort"
|
|
commPatterns sync.Map
|
|
// ReachableAt is the URL where this server can be reached
|
|
ReachableAt string
|
|
}
|
|
|
|
// NewUDPProxyServer initializes the server with a buffered packet channel.
|
|
func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
|
|
return &UDPProxyServer{
|
|
addr: addr,
|
|
serverURL: serverURL,
|
|
privateKey: privateKey,
|
|
packetChan: make(chan Packet, 1000),
|
|
ReachableAt: reachableAt,
|
|
}
|
|
}
|
|
|
|
// Start sets up the UDP listener, worker pool, and begins reading packets.
|
|
func (s *UDPProxyServer) Start() error {
|
|
// Fetch initial mappings.
|
|
if err := s.fetchInitialMappings(); err != nil {
|
|
return fmt.Errorf("failed to fetch initial mappings: %v", err)
|
|
}
|
|
|
|
udpAddr, err := net.ResolveUDPAddr("udp", s.addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
conn, err := net.ListenUDP("udp", udpAddr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.conn = conn
|
|
logger.Info("UDP server listening on %s", s.addr)
|
|
|
|
// Start a fixed number of worker goroutines.
|
|
workerCount := 10 // TODO: Make this configurable or pick it better!
|
|
for i := 0; i < workerCount; i++ {
|
|
go s.packetWorker()
|
|
}
|
|
|
|
// Start the goroutine that reads packets from the UDP socket.
|
|
go s.readPackets()
|
|
|
|
// Start the idle connection cleanup routine.
|
|
go s.cleanupIdleConnections()
|
|
|
|
// Start the session cleanup routine
|
|
go s.cleanupIdleSessions()
|
|
|
|
// Start the proxy mapping cleanup routine
|
|
go s.cleanupIdleProxyMappings()
|
|
|
|
// Start the communication pattern cleanup routine
|
|
go s.cleanupIdleCommunicationPatterns()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *UDPProxyServer) Stop() {
|
|
s.conn.Close()
|
|
}
|
|
|
|
// readPackets continuously reads from the UDP socket and pushes packets into the channel.
|
|
func (s *UDPProxyServer) readPackets() {
|
|
for {
|
|
buf := bufferPool.Get().([]byte)
|
|
n, remoteAddr, err := s.conn.ReadFromUDP(buf)
|
|
if err != nil {
|
|
logger.Error("Error reading UDP packet: %v", err)
|
|
continue
|
|
}
|
|
s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
|
|
}
|
|
}
|
|
|
|
// packetWorker processes incoming packets from the channel.
|
|
func (s *UDPProxyServer) packetWorker() {
|
|
for packet := range s.packetChan {
|
|
// Determine packet type by inspecting the first byte.
|
|
if packet.n > 0 && packet.data[0] >= 1 && packet.data[0] <= 4 {
|
|
// Process as a WireGuard packet.
|
|
s.handleWireGuardPacket(packet.data, packet.remoteAddr)
|
|
} else {
|
|
// Process as an encrypted hole punch message
|
|
var encMsg EncryptedHolePunchMessage
|
|
if err := json.Unmarshal(packet.data, &encMsg); err != nil {
|
|
logger.Error("Error unmarshaling encrypted message: %v", err)
|
|
// Return the buffer to the pool for reuse and continue with next packet
|
|
bufferPool.Put(packet.data[:1500])
|
|
continue
|
|
}
|
|
|
|
if encMsg.EphemeralPublicKey == "" {
|
|
logger.Error("Received malformed message without ephemeral key")
|
|
// Return the buffer to the pool for reuse and continue with next packet
|
|
bufferPool.Put(packet.data[:1500])
|
|
continue
|
|
}
|
|
|
|
// This appears to be an encrypted message
|
|
decryptedData, err := s.decryptMessage(encMsg)
|
|
if err != nil {
|
|
logger.Error("Failed to decrypt message: %v", err)
|
|
// Return the buffer to the pool for reuse and continue with next packet
|
|
bufferPool.Put(packet.data[:1500])
|
|
continue
|
|
}
|
|
|
|
// Process the decrypted hole punch message
|
|
var msg HolePunchMessage
|
|
if err := json.Unmarshal(decryptedData, &msg); err != nil {
|
|
logger.Error("Error unmarshaling decrypted message: %v", err)
|
|
// Return the buffer to the pool for reuse and continue with next packet
|
|
bufferPool.Put(packet.data[:1500])
|
|
continue
|
|
}
|
|
|
|
endpoint := ClientEndpoint{
|
|
NewtID: msg.NewtID,
|
|
OlmID: msg.OlmID,
|
|
Token: msg.Token,
|
|
IP: packet.remoteAddr.IP.String(),
|
|
Port: packet.remoteAddr.Port,
|
|
Timestamp: time.Now().Unix(),
|
|
ReachableAt: s.ReachableAt,
|
|
PublicKey: s.privateKey.PublicKey().String(),
|
|
}
|
|
logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", packet.remoteAddr.String(), endpoint.IP, endpoint.Port)
|
|
s.notifyServer(endpoint)
|
|
s.clearSessionsForIP(endpoint.IP) // Clear sessions for this IP to allow re-establishment
|
|
}
|
|
// Return the buffer to the pool for reuse.
|
|
bufferPool.Put(packet.data[:1500])
|
|
}
|
|
}
|
|
|
|
// decryptMessage decrypts the message using the server's private key
|
|
func (s *UDPProxyServer) decryptMessage(encMsg EncryptedHolePunchMessage) ([]byte, error) {
|
|
// Parse the ephemeral public key
|
|
ephPubKey, err := wgtypes.ParseKey(encMsg.EphemeralPublicKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse ephemeral public key: %v", err)
|
|
}
|
|
|
|
// Use X25519 for key exchange instead of ScalarMult
|
|
sharedSecret, err := curve25519.X25519(s.privateKey[:], ephPubKey[:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
|
}
|
|
|
|
// Create the AEAD cipher using the shared secret
|
|
aead, err := chacha20poly1305.New(sharedSecret)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
|
}
|
|
|
|
// Verify nonce size
|
|
if len(encMsg.Nonce) != aead.NonceSize() {
|
|
return nil, fmt.Errorf("invalid nonce size")
|
|
}
|
|
|
|
// Decrypt the ciphertext
|
|
plaintext, err := aead.Open(nil, encMsg.Nonce, encMsg.Ciphertext, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt message: %v", err)
|
|
}
|
|
|
|
return plaintext, nil
|
|
}
|
|
|
|
func (s *UDPProxyServer) fetchInitialMappings() error {
|
|
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, s.privateKey.PublicKey().String())))
|
|
resp, err := http.Post(s.serverURL+"/gerbil/get-all-relays", "application/json", body)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to fetch mappings: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("server returned non-OK status: %d, body: %s",
|
|
resp.StatusCode, string(body))
|
|
}
|
|
data, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read response body: %v", err)
|
|
}
|
|
logger.Info("Received initial mappings: %s", string(data))
|
|
var initialMappings InitialMappings
|
|
if err := json.Unmarshal(data, &initialMappings); err != nil {
|
|
return fmt.Errorf("failed to unmarshal initial mappings: %v", err)
|
|
}
|
|
// Store mappings in our sync.Map.
|
|
for key, mapping := range initialMappings.Mappings {
|
|
// Initialize LastUsed timestamp for initial mappings
|
|
mapping.LastUsed = time.Now()
|
|
s.proxyMappings.Store(key, mapping)
|
|
}
|
|
logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings))
|
|
return nil
|
|
}
|
|
|
|
// Extract WireGuard message indices
|
|
func extractWireGuardIndices(packet []byte) (uint32, uint32, bool) {
|
|
if len(packet) < 12 {
|
|
return 0, 0, false
|
|
}
|
|
|
|
messageType := packet[0]
|
|
if messageType == WireGuardMessageTypeHandshakeInitiation {
|
|
// Handshake initiation: extract sender index at offset 4
|
|
senderIndex := binary.LittleEndian.Uint32(packet[4:8])
|
|
return 0, senderIndex, true
|
|
} else if messageType == WireGuardMessageTypeHandshakeResponse {
|
|
// Handshake response: extract sender index at offset 4 and receiver index at offset 8
|
|
senderIndex := binary.LittleEndian.Uint32(packet[4:8])
|
|
receiverIndex := binary.LittleEndian.Uint32(packet[8:12])
|
|
return receiverIndex, senderIndex, true
|
|
} else if messageType == WireGuardMessageTypeTransportData {
|
|
// Transport data: extract receiver index at offset 4
|
|
receiverIndex := binary.LittleEndian.Uint32(packet[4:8])
|
|
return receiverIndex, 0, true
|
|
}
|
|
|
|
return 0, 0, false
|
|
}
|
|
|
|
// Updated to handle multi-peer WireGuard communication
|
|
func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
|
|
if len(packet) == 0 {
|
|
logger.Error("Received empty packet")
|
|
return
|
|
}
|
|
|
|
messageType := packet[0]
|
|
receiverIndex, senderIndex, ok := extractWireGuardIndices(packet)
|
|
|
|
if !ok {
|
|
logger.Error("Failed to extract WireGuard indices")
|
|
return
|
|
}
|
|
|
|
key := remoteAddr.String()
|
|
mappingObj, ok := s.proxyMappings.Load(key)
|
|
if !ok {
|
|
logger.Error("No proxy mapping found for %s", key)
|
|
return
|
|
}
|
|
|
|
proxyMapping := mappingObj.(ProxyMapping)
|
|
// Update the last used timestamp and store it back
|
|
proxyMapping.LastUsed = time.Now()
|
|
s.proxyMappings.Store(key, proxyMapping)
|
|
|
|
// Handle different WireGuard message types
|
|
switch messageType {
|
|
case WireGuardMessageTypeHandshakeInitiation:
|
|
// Initial handshake: forward to all peers
|
|
logger.Debug("Forwarding handshake initiation from %s (sender index: %d) to peers %v", remoteAddr, senderIndex, proxyMapping.Destinations)
|
|
|
|
for _, dest := range proxyMapping.Destinations {
|
|
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
|
if err != nil {
|
|
logger.Error("Failed to resolve destination address: %v", err)
|
|
continue
|
|
}
|
|
|
|
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
|
if err != nil {
|
|
logger.Error("Failed to get/create connection: %v", err)
|
|
continue
|
|
}
|
|
|
|
_, err = conn.Write(packet)
|
|
if err != nil {
|
|
logger.Error("Failed to forward handshake initiation: %v", err)
|
|
}
|
|
}
|
|
|
|
case WireGuardMessageTypeHandshakeResponse:
|
|
// Received handshake response: establish session mapping
|
|
logger.Debug("Received handshake response with receiver index %d and sender index %d from %s",
|
|
receiverIndex, senderIndex, remoteAddr)
|
|
|
|
// Create a session key for the peer that sent the initial handshake
|
|
sessionKey := fmt.Sprintf("%d:%d", receiverIndex, senderIndex)
|
|
|
|
// Store the session information
|
|
s.wgSessions.Store(sessionKey, &WireGuardSession{
|
|
ReceiverIndex: receiverIndex,
|
|
SenderIndex: senderIndex,
|
|
DestAddr: remoteAddr,
|
|
LastSeen: time.Now(),
|
|
})
|
|
|
|
// Forward the response to the original sender
|
|
for _, dest := range proxyMapping.Destinations {
|
|
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
|
if err != nil {
|
|
logger.Error("Failed to resolve destination address: %v", err)
|
|
continue
|
|
}
|
|
|
|
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
|
if err != nil {
|
|
logger.Error("Failed to get/create connection: %v", err)
|
|
continue
|
|
}
|
|
|
|
_, err = conn.Write(packet)
|
|
if err != nil {
|
|
logger.Error("Failed to forward handshake response: %v", err)
|
|
}
|
|
}
|
|
|
|
case WireGuardMessageTypeTransportData:
|
|
// Data packet: forward only to the established session peer
|
|
// logger.Debug("Received transport data with receiver index %d from %s", receiverIndex, remoteAddr)
|
|
|
|
// Look up the session based on the receiver index
|
|
var destAddr *net.UDPAddr
|
|
|
|
// First check for existing sessions to see if we know where to send this packet
|
|
s.wgSessions.Range(func(k, v interface{}) bool {
|
|
session := v.(*WireGuardSession)
|
|
if session.SenderIndex == receiverIndex {
|
|
// Found matching session
|
|
destAddr = session.DestAddr
|
|
|
|
// Update last seen time
|
|
session.LastSeen = time.Now()
|
|
s.wgSessions.Store(k, session)
|
|
return false // stop iteration
|
|
}
|
|
return true // continue iteration
|
|
})
|
|
|
|
if destAddr != nil {
|
|
// We found a specific peer to forward to
|
|
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
|
if err != nil {
|
|
logger.Error("Failed to get/create connection: %v", err)
|
|
return
|
|
}
|
|
|
|
// Track communication pattern for session rebuilding
|
|
s.trackCommunicationPattern(remoteAddr, destAddr, receiverIndex, true)
|
|
|
|
_, err = conn.Write(packet)
|
|
if err != nil {
|
|
logger.Debug("Failed to forward transport data: %v", err)
|
|
}
|
|
} else {
|
|
// No known session, fall back to forwarding to all peers
|
|
logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex)
|
|
for _, dest := range proxyMapping.Destinations {
|
|
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
|
if err != nil {
|
|
logger.Error("Failed to resolve destination address: %v", err)
|
|
continue
|
|
}
|
|
|
|
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
|
if err != nil {
|
|
logger.Error("Failed to get/create connection: %v", err)
|
|
continue
|
|
}
|
|
|
|
// Track communication pattern for session rebuilding
|
|
s.trackCommunicationPattern(remoteAddr, destAddr, receiverIndex, true)
|
|
|
|
_, err = conn.Write(packet)
|
|
if err != nil {
|
|
logger.Debug("Failed to forward transport data: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
default:
|
|
// Other packet types (like cookie reply)
|
|
logger.Debug("Forwarding WireGuard packet type %d from %s", messageType, remoteAddr)
|
|
|
|
// Forward to all peers
|
|
for _, dest := range proxyMapping.Destinations {
|
|
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
|
if err != nil {
|
|
logger.Error("Failed to resolve destination address: %v", err)
|
|
continue
|
|
}
|
|
|
|
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
|
if err != nil {
|
|
logger.Error("Failed to get/create connection: %v", err)
|
|
continue
|
|
}
|
|
|
|
_, err = conn.Write(packet)
|
|
if err != nil {
|
|
logger.Error("Failed to forward WireGuard packet: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *UDPProxyServer) getOrCreateConnection(destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) (*net.UDPConn, error) {
|
|
key := destAddr.String() + "-" + remoteAddr.String()
|
|
|
|
// Check if we have an existing connection
|
|
if conn, ok := s.connections.Load(key); ok {
|
|
destConn := conn.(*DestinationConn)
|
|
destConn.lastUsed = time.Now()
|
|
return destConn.conn, nil
|
|
}
|
|
|
|
// Create new connection
|
|
newConn, err := net.DialUDP("udp", nil, destAddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create UDP connection: %v", err)
|
|
}
|
|
|
|
// Store the new connection
|
|
s.connections.Store(key, &DestinationConn{
|
|
conn: newConn,
|
|
lastUsed: time.Now(),
|
|
})
|
|
|
|
// Start a goroutine to handle responses
|
|
go s.handleResponses(newConn, destAddr, remoteAddr)
|
|
|
|
return newConn, nil
|
|
}
|
|
|
|
func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) {
|
|
buffer := make([]byte, 1500)
|
|
for {
|
|
n, err := conn.Read(buffer)
|
|
if err != nil {
|
|
logger.Debug("Error reading response from %s: %v", destAddr.String(), err)
|
|
return
|
|
}
|
|
|
|
// Process the response to track sessions if it's a WireGuard packet
|
|
if n > 0 && buffer[0] >= 1 && buffer[0] <= 4 {
|
|
receiverIndex, senderIndex, ok := extractWireGuardIndices(buffer[:n])
|
|
if ok && buffer[0] == WireGuardMessageTypeHandshakeResponse {
|
|
// Store the session mapping for the handshake response
|
|
sessionKey := fmt.Sprintf("%d:%d", senderIndex, receiverIndex)
|
|
s.wgSessions.Store(sessionKey, &WireGuardSession{
|
|
ReceiverIndex: receiverIndex,
|
|
SenderIndex: senderIndex,
|
|
DestAddr: destAddr,
|
|
LastSeen: time.Now(),
|
|
})
|
|
logger.Debug("Stored session mapping: %s -> %s", sessionKey, destAddr.String())
|
|
} else if ok && buffer[0] == WireGuardMessageTypeTransportData {
|
|
// Track communication pattern for session rebuilding (reverse direction)
|
|
s.trackCommunicationPattern(destAddr, remoteAddr, receiverIndex, false)
|
|
}
|
|
}
|
|
|
|
// Forward the response back through the main listener
|
|
_, err = s.conn.WriteToUDP(buffer[:n], remoteAddr)
|
|
if err != nil {
|
|
logger.Error("Failed to forward response: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add a cleanup method to periodically remove idle connections
|
|
func (s *UDPProxyServer) cleanupIdleConnections() {
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
for range ticker.C {
|
|
now := time.Now()
|
|
s.connections.Range(func(key, value interface{}) bool {
|
|
destConn := value.(*DestinationConn)
|
|
if now.Sub(destConn.lastUsed) > 10*time.Minute {
|
|
destConn.conn.Close()
|
|
s.connections.Delete(key)
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}
|
|
|
|
// New method to periodically remove idle sessions
|
|
func (s *UDPProxyServer) cleanupIdleSessions() {
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
for range ticker.C {
|
|
now := time.Now()
|
|
s.wgSessions.Range(func(key, value interface{}) bool {
|
|
session := value.(*WireGuardSession)
|
|
if now.Sub(session.LastSeen) > 15*time.Minute {
|
|
s.wgSessions.Delete(key)
|
|
logger.Debug("Removed idle session: %s", key)
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}
|
|
|
|
// New method to periodically remove idle proxy mappings
|
|
func (s *UDPProxyServer) cleanupIdleProxyMappings() {
|
|
ticker := time.NewTicker(10 * time.Minute)
|
|
for range ticker.C {
|
|
now := time.Now()
|
|
s.proxyMappings.Range(func(key, value interface{}) bool {
|
|
mapping := value.(ProxyMapping)
|
|
// Remove mappings that haven't been used in 30 minutes
|
|
if now.Sub(mapping.LastUsed) > 30*time.Minute {
|
|
s.proxyMappings.Delete(key)
|
|
logger.Debug("Removed idle proxy mapping: %s", key)
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}
|
|
|
|
func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) {
|
|
logger.Debug("notifyServer called with endpoint: IP=%s, Port=%d", endpoint.IP, endpoint.Port)
|
|
|
|
jsonData, err := json.Marshal(endpoint)
|
|
if err != nil {
|
|
logger.Error("Failed to marshal endpoint data: %v", err)
|
|
return
|
|
}
|
|
|
|
resp, err := http.Post(s.serverURL+"/gerbil/update-hole-punch", "application/json", bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
logger.Error("Failed to notify server: %v", err)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
logger.Error("Server returned non-OK status: %d, body: %s",
|
|
resp.StatusCode, string(body))
|
|
return
|
|
}
|
|
|
|
// Parse the proxy mapping response
|
|
var mapping ProxyMapping
|
|
if err := json.NewDecoder(resp.Body).Decode(&mapping); err != nil {
|
|
logger.Error("Failed to decode proxy mapping: %v", err)
|
|
return
|
|
}
|
|
|
|
logger.Debug("Received proxy mapping from server: %v", mapping)
|
|
|
|
// Store the mapping with current timestamp
|
|
key := fmt.Sprintf("%s:%d", endpoint.IP, endpoint.Port)
|
|
logger.Debug("About to store proxy mapping with key: %s (from endpoint IP=%s, Port=%d)", key, endpoint.IP, endpoint.Port)
|
|
mapping.LastUsed = time.Now()
|
|
s.proxyMappings.Store(key, mapping)
|
|
|
|
logger.Debug("Stored proxy mapping for %s with %d destinations (timestamp: %v)", key, len(mapping.Destinations), mapping.LastUsed)
|
|
}
|
|
|
|
// Updated to support multiple destinations
|
|
func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, destinations []PeerDestination) {
|
|
key := fmt.Sprintf("%s:%d", sourceIP, sourcePort)
|
|
mapping := ProxyMapping{
|
|
Destinations: destinations,
|
|
LastUsed: time.Now(),
|
|
}
|
|
s.proxyMappings.Store(key, mapping)
|
|
}
|
|
|
|
// OnPeerAdded clears connections and sessions for a specific WireGuard IP to allow re-establishment
|
|
func (s *UDPProxyServer) OnPeerAdded(wgIP string) {
|
|
logger.Info("Clearing connections for added peer with WG IP: %s", wgIP)
|
|
s.clearConnectionsForWGIP(wgIP)
|
|
// s.clearSessionsForWGIP(wgIP) THE DEST ADDR IS NOT THE WG IP, SO THIS IS NOT NEEDED
|
|
// s.clearProxyMappingsForWGIP(wgIP)
|
|
}
|
|
|
|
// OnPeerRemoved clears connections and sessions for a specific WireGuard IP
|
|
func (s *UDPProxyServer) OnPeerRemoved(wgIP string) {
|
|
logger.Info("Clearing connections for removed peer with WG IP: %s", wgIP)
|
|
s.clearConnectionsForWGIP(wgIP)
|
|
// s.clearSessionsForWGIP(wgIP) THE DEST ADDR IS NOT THE WG IP, SO THIS IS NOT NEEDED
|
|
// s.clearProxyMappingsForWGIP(wgIP)
|
|
}
|
|
|
|
// clearConnectionsForWGIP removes all connections associated with a specific WireGuard IP
|
|
func (s *UDPProxyServer) clearConnectionsForWGIP(wgIP string) {
|
|
var keysToDelete []string
|
|
|
|
s.connections.Range(func(key, value interface{}) bool {
|
|
keyStr := key.(string)
|
|
destConn := value.(*DestinationConn)
|
|
|
|
// Connection keys are in format "destAddr-remoteAddr"
|
|
// Check if either destination or remote address contains the WG IP
|
|
if containsIP(keyStr, wgIP) {
|
|
keysToDelete = append(keysToDelete, keyStr)
|
|
destConn.conn.Close()
|
|
logger.Debug("Closing connection for WG IP %s: %s", wgIP, keyStr)
|
|
}
|
|
return true
|
|
})
|
|
|
|
// Delete the connections
|
|
for _, key := range keysToDelete {
|
|
s.connections.Delete(key)
|
|
}
|
|
|
|
logger.Info("Cleared %d connections for WG IP: %s", len(keysToDelete), wgIP)
|
|
}
|
|
|
|
// clearSessionsForWGIP removes all WireGuard sessions associated with a specific WireGuard IP
|
|
func (s *UDPProxyServer) clearSessionsForIP(ip string) {
|
|
var keysToDelete []string
|
|
|
|
s.wgSessions.Range(func(key, value interface{}) bool {
|
|
keyStr := key.(string)
|
|
session := value.(*WireGuardSession)
|
|
|
|
// Check if the session's destination address contains the WG IP
|
|
if session.DestAddr != nil && session.DestAddr.IP.String() == ip {
|
|
keysToDelete = append(keysToDelete, keyStr)
|
|
logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr)
|
|
}
|
|
return true
|
|
})
|
|
|
|
// Delete the sessions
|
|
for _, key := range keysToDelete {
|
|
s.wgSessions.Delete(key)
|
|
}
|
|
|
|
logger.Info("Cleared %d sessions for WG IP: %s", len(keysToDelete), ip)
|
|
}
|
|
|
|
// // clearProxyMappingsForWGIP removes all proxy mappings that have destinations pointing to a specific WireGuard IP
|
|
// func (s *UDPProxyServer) clearProxyMappingsForWGIP(wgIP string) {
|
|
// var keysToDelete []string
|
|
|
|
// s.proxyMappings.Range(func(key, value interface{}) bool {
|
|
// keyStr := key.(string)
|
|
// mapping := value.(ProxyMapping)
|
|
|
|
// // Check if any destination in the mapping contains the WG IP
|
|
// for _, dest := range mapping.Destinations {
|
|
// if dest.DestinationIP == wgIP {
|
|
// keysToDelete = append(keysToDelete, keyStr)
|
|
// logger.Debug("Marking proxy mapping for deletion for WG IP %s: %s -> %s:%d", wgIP, keyStr, dest.DestinationIP, dest.DestinationPort)
|
|
// break // Found one destination, no need to check others in this mapping
|
|
// }
|
|
// }
|
|
// return true
|
|
// })
|
|
|
|
// // Delete the proxy mappings
|
|
// for _, key := range keysToDelete {
|
|
// s.proxyMappings.Delete(key)
|
|
// logger.Debug("Deleted proxy mapping: %s", key)
|
|
// }
|
|
|
|
// logger.Info("Cleared %d proxy mappings for WG IP: %s", len(keysToDelete), wgIP)
|
|
// }
|
|
|
|
// containsIP checks if a connection key string contains the specified IP address
|
|
func containsIP(connectionKey, ip string) bool {
|
|
// Connection keys are in format "destIP:destPort-remoteIP:remotePort"
|
|
// Check if the IP appears at the beginning (destination) or after the dash (remote)
|
|
ipWithColon := ip + ":"
|
|
|
|
// Check if connection key starts with the IP (destination address)
|
|
if len(connectionKey) >= len(ipWithColon) && connectionKey[:len(ipWithColon)] == ipWithColon {
|
|
return true
|
|
}
|
|
|
|
// Check if connection key contains the IP after a dash (remote address)
|
|
dashIndex := -1
|
|
for i := 0; i < len(connectionKey); i++ {
|
|
if connectionKey[i] == '-' {
|
|
dashIndex = i
|
|
break
|
|
}
|
|
}
|
|
|
|
if dashIndex != -1 && dashIndex+1 < len(connectionKey) {
|
|
remainingPart := connectionKey[dashIndex+1:]
|
|
if len(remainingPart) >= len(ip)+1 && remainingPart[:len(ip)+1] == ipWithColon {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// UpdateDestinationInMappings updates all proxy mappings that contain the old destination with the new destination
|
|
// Returns the number of mappings that were updated
|
|
func (s *UDPProxyServer) UpdateDestinationInMappings(oldDest, newDest PeerDestination) int {
|
|
updatedCount := 0
|
|
|
|
s.proxyMappings.Range(func(key, value interface{}) bool {
|
|
keyStr := key.(string)
|
|
mapping := value.(ProxyMapping)
|
|
updated := false
|
|
|
|
// Check each destination in the mapping
|
|
for i, dest := range mapping.Destinations {
|
|
if dest.DestinationIP == oldDest.DestinationIP && dest.DestinationPort == oldDest.DestinationPort {
|
|
// Update this destination
|
|
mapping.Destinations[i] = newDest
|
|
updated = true
|
|
logger.Debug("Updated destination in mapping %s: %s:%d -> %s:%d",
|
|
keyStr, oldDest.DestinationIP, oldDest.DestinationPort,
|
|
newDest.DestinationIP, newDest.DestinationPort)
|
|
}
|
|
}
|
|
|
|
// If we updated any destinations, store the updated mapping back
|
|
if updated {
|
|
mapping.LastUsed = time.Now()
|
|
s.proxyMappings.Store(keyStr, mapping)
|
|
updatedCount++
|
|
}
|
|
|
|
return true // continue iteration
|
|
})
|
|
|
|
if updatedCount > 0 {
|
|
logger.Info("Updated %d proxy mappings from %s:%d to %s:%d",
|
|
updatedCount, oldDest.DestinationIP, oldDest.DestinationPort,
|
|
newDest.DestinationIP, newDest.DestinationPort)
|
|
}
|
|
|
|
return updatedCount
|
|
}
|
|
|
|
// trackCommunicationPattern tracks bidirectional communication patterns to rebuild sessions
|
|
func (s *UDPProxyServer) trackCommunicationPattern(fromAddr, toAddr *net.UDPAddr, receiverIndex uint32, fromClient bool) {
|
|
var clientAddr, destAddr *net.UDPAddr
|
|
var clientIndex, destIndex uint32
|
|
|
|
if fromClient {
|
|
clientAddr = fromAddr
|
|
destAddr = toAddr
|
|
clientIndex = receiverIndex
|
|
destIndex = 0 // We don't know the destination index yet
|
|
} else {
|
|
clientAddr = toAddr
|
|
destAddr = fromAddr
|
|
clientIndex = 0 // We don't know the client index yet
|
|
destIndex = receiverIndex
|
|
}
|
|
|
|
patternKey := fmt.Sprintf("%s-%s", clientAddr.String(), destAddr.String())
|
|
now := time.Now()
|
|
|
|
if existingPattern, ok := s.commPatterns.Load(patternKey); ok {
|
|
pattern := existingPattern.(*CommunicationPattern)
|
|
|
|
// Update the pattern
|
|
if fromClient {
|
|
pattern.LastFromClient = now
|
|
if pattern.ClientIndex == 0 {
|
|
pattern.ClientIndex = clientIndex
|
|
}
|
|
} else {
|
|
pattern.LastFromDest = now
|
|
if pattern.DestIndex == 0 {
|
|
pattern.DestIndex = destIndex
|
|
}
|
|
}
|
|
|
|
pattern.PacketCount++
|
|
s.commPatterns.Store(patternKey, pattern)
|
|
|
|
// Check if we have bidirectional communication and can rebuild a session
|
|
s.tryRebuildSession(pattern)
|
|
} else {
|
|
// Create new pattern
|
|
pattern := &CommunicationPattern{
|
|
FromClient: clientAddr,
|
|
ToDestination: destAddr,
|
|
ClientIndex: clientIndex,
|
|
DestIndex: destIndex,
|
|
PacketCount: 1,
|
|
}
|
|
|
|
if fromClient {
|
|
pattern.LastFromClient = now
|
|
} else {
|
|
pattern.LastFromDest = now
|
|
}
|
|
|
|
s.commPatterns.Store(patternKey, pattern)
|
|
}
|
|
}
|
|
|
|
// tryRebuildSession attempts to rebuild a WireGuard session from communication patterns
|
|
func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
|
|
// Check if we have bidirectional communication within a reasonable time window
|
|
timeDiff := pattern.LastFromClient.Sub(pattern.LastFromDest)
|
|
if timeDiff < 0 {
|
|
timeDiff = -timeDiff
|
|
}
|
|
|
|
// Only rebuild if we have recent bidirectional communication and both indices
|
|
if timeDiff < 30*time.Second && pattern.ClientIndex != 0 && pattern.DestIndex != 0 && pattern.PacketCount >= 4 {
|
|
// Create session mapping: client's index maps to destination
|
|
sessionKey := fmt.Sprintf("%d:%d", pattern.DestIndex, pattern.ClientIndex)
|
|
|
|
// Check if we already have this session
|
|
if _, exists := s.wgSessions.Load(sessionKey); !exists {
|
|
session := &WireGuardSession{
|
|
ReceiverIndex: pattern.DestIndex,
|
|
SenderIndex: pattern.ClientIndex,
|
|
DestAddr: pattern.ToDestination,
|
|
LastSeen: time.Now(),
|
|
}
|
|
|
|
s.wgSessions.Store(sessionKey, session)
|
|
logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)",
|
|
sessionKey, pattern.ToDestination.String(), pattern.PacketCount)
|
|
}
|
|
}
|
|
}
|
|
|
|
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns
|
|
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
|
|
ticker := time.NewTicker(10 * time.Minute)
|
|
for range ticker.C {
|
|
now := time.Now()
|
|
s.commPatterns.Range(func(key, value interface{}) bool {
|
|
pattern := value.(*CommunicationPattern)
|
|
|
|
// Get the most recent activity
|
|
lastActivity := pattern.LastFromClient
|
|
if pattern.LastFromDest.After(lastActivity) {
|
|
lastActivity = pattern.LastFromDest
|
|
}
|
|
|
|
// Remove patterns that haven't had activity in 20 minutes
|
|
if now.Sub(lastActivity) > 20*time.Minute {
|
|
s.commPatterns.Delete(key)
|
|
logger.Debug("Removed idle communication pattern: %s", key)
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}
|