mirror of
https://github.com/fosrl/gerbil.git
synced 2026-02-07 21:46:40 +00:00
The race condition existed because while sync.Map is thread-safe for map operations (Load, Store, Delete, Range), it does not provide thread-safety for the data stored within it. When WireGuardSession structs were stored as pointers in the sync.Map, multiple goroutines could: 1. Retrieve the same session pointer from the map concurrently 2. Access and modify the session's fields (particularly LastSeen) without synchronization 3. Cause data races when one goroutine reads LastSeen while another updates it This fix adds a sync.RWMutex to each WireGuardSession struct to protect concurrent access to its fields. All field access now goes through thread-safe methods that properly acquire/release the mutex. Changes: - Added sync.RWMutex to WireGuardSession struct - Added thread-safe accessor methods (GetLastSeen, GetDestAddr, etc.) - Added atomic CheckAndUpdateIfMatch method for efficient check-and-update - Updated all session field accesses to use thread-safe methods - Removed redundant Store call after updating LastSeen (the sync.Map already held the pointer, so re-storing it was unnecessary; the race was in the unsynchronized field access through that pointer)
1014 lines
32 KiB
Go
1014 lines
32 KiB
Go
package relay
|
|
|
|
import (
	"bytes"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/fosrl/gerbil/logger"
	"golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/crypto/curve25519"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
|
|
|
|
// EncryptedHolePunchMessage is the wire format of a hole punch request:
// an ephemeral X25519 public key plus a ChaCha20-Poly1305 nonce and
// ciphertext that decryptMessage opens into a HolePunchMessage.
type EncryptedHolePunchMessage struct {
	EphemeralPublicKey string `json:"ephemeralPublicKey"` // sender's ephemeral key, wgtypes base64 format
	Nonce              []byte `json:"nonce"`              // AEAD nonce; length must equal the cipher's nonce size
	Ciphertext         []byte `json:"ciphertext"`         // AEAD-sealed JSON of a HolePunchMessage
}
|
|
|
|
// HolePunchMessage is the decrypted payload of an EncryptedHolePunchMessage:
// the client identifiers and the auth token forwarded to the control server.
type HolePunchMessage struct {
	OlmID  string `json:"olmId"`
	NewtID string `json:"newtId"`
	Token  string `json:"token"`
}
|
|
|
|
// ClientEndpoint is reported to the control server after a hole punch:
// the client identifiers plus the source address observed on the UDP
// socket, and this relay's own public key and reachable URL.
type ClientEndpoint struct {
	OlmID       string `json:"olmId"`
	NewtID      string `json:"newtId"`
	Token       string `json:"token"`
	IP          string `json:"ip"`          // source IP observed on the hole punch packet
	Port        int    `json:"port"`        // source port observed on the hole punch packet
	Timestamp   int64  `json:"timestamp"`   // Unix seconds when the packet was processed
	ReachableAt string `json:"reachableAt"` // URL where this relay can be reached
	PublicKey   string `json:"publicKey"`   // this server's WireGuard public key
}
|
|
|
|
// Updated to support multiple destination peers.
// ProxyMapping lists the peers that packets from one source address
// (the "ip:port" key in proxyMappings) are relayed to.
type ProxyMapping struct {
	Destinations []PeerDestination `json:"destinations"`
	LastUsed     time.Time         `json:"-"` // Not serialized, used for cleanup
}
|
|
|
|
// PeerDestination is a single relay target: an IP and a UDP port.
type PeerDestination struct {
	DestinationIP   string `json:"destinationIP"`
	DestinationPort int    `json:"destinationPort"`
}
|
|
|
|
// DestinationConn caches a dialed UDP socket for one destination, along
// with the last time it carried traffic (used by cleanupIdleConnections).
// NOTE(review): lastUsed is written by getOrCreateConnection and read by
// the cleanup loop without synchronization — technically a data race;
// confirm and guard if it matters.
type DestinationConn struct {
	conn     *net.UDPConn
	lastUsed time.Time
}
|
|
|
|
// Type for storing WireGuard handshake information.
// A WireGuardSession maps an index pair observed during a handshake to
// the peer address that matching transport data is forwarded to. Session
// values are shared as pointers via the wgSessions sync.Map, so all field
// access must go through the mutex-protected methods below.
type WireGuardSession struct {
	mu            sync.RWMutex // protects all fields below
	ReceiverIndex uint32       // receiver index from the handshake that created this session
	SenderIndex   uint32       // compared against the receiver index of incoming transport data
	DestAddr      *net.UDPAddr // where matching transport data is forwarded
	LastSeen      time.Time    // refreshed on every matched packet; drives idle cleanup
}
|
|
|
|
// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner
|
|
func (s *WireGuardSession) UpdateLastSeen() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.LastSeen = time.Now()
|
|
}
|
|
|
|
// GetSenderIndex returns the SenderIndex in a thread-safe manner
|
|
func (s *WireGuardSession) GetSenderIndex() uint32 {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.SenderIndex
|
|
}
|
|
|
|
// GetDestAddr returns the DestAddr in a thread-safe manner
|
|
func (s *WireGuardSession) GetDestAddr() *net.UDPAddr {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.DestAddr
|
|
}
|
|
|
|
// GetLastSeen returns the LastSeen timestamp in a thread-safe manner
|
|
func (s *WireGuardSession) GetLastSeen() time.Time {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.LastSeen
|
|
}
|
|
|
|
// MatchesSenderIndex checks if the SenderIndex matches the given value in a thread-safe manner
|
|
func (s *WireGuardSession) MatchesSenderIndex(receiverIndex uint32) bool {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.SenderIndex == receiverIndex
|
|
}
|
|
|
|
// CheckAndUpdateIfMatch atomically checks if SenderIndex matches and updates LastSeen if it does.
|
|
// Returns the DestAddr and true if there's a match, nil and false otherwise.
|
|
// This is more efficient than separate MatchesSenderIndex and UpdateLastSeen calls.
|
|
func (s *WireGuardSession) CheckAndUpdateIfMatch(receiverIndex uint32) (*net.UDPAddr, bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.SenderIndex == receiverIndex {
|
|
s.LastSeen = time.Now()
|
|
return s.DestAddr, true
|
|
}
|
|
return nil, false
|
|
}
|
|
|
|
// Type for tracking bidirectional communication patterns to rebuild sessions.
// A CommunicationPattern accumulates per-direction timestamps, the
// indices seen in each direction, and a packet count for one
// client/destination pair; tryRebuildSession promotes it to a
// WireGuardSession once traffic is bidirectional.
// NOTE(review): unlike WireGuardSession, these fields are mutated by
// trackCommunicationPattern without a lock while shared as pointers in
// commPatterns — the same class of race the session mutex fixed; confirm
// and fix separately.
type CommunicationPattern struct {
	FromClient     *net.UDPAddr // The client address
	ToDestination  *net.UDPAddr // The destination address
	ClientIndex    uint32       // The receiver index seen from client
	DestIndex      uint32       // The receiver index seen from destination
	LastFromClient time.Time    // Last packet from client to destination
	LastFromDest   time.Time    // Last packet from destination to client
	PacketCount    int          // Number of packets observed
}
|
|
|
|
// InitialMappings is the control server's response to get-all-relays:
// every proxy mapping this relay should be seeded with at startup.
type InitialMappings struct {
	Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port"
}
|
|
|
|
// Packet is a simple struct to hold the packet data and sender info.
type Packet struct {
	data       []byte       // payload slice, backed by a bufferPool buffer
	remoteAddr *net.UDPAddr // source address of the datagram
	n          int          // number of valid bytes in data
}
|
|
|
|
// WireGuard message types (the first byte of every WireGuard packet).
const (
	WireGuardMessageTypeHandshakeInitiation = 1
	WireGuardMessageTypeHandshakeResponse   = 2
	WireGuardMessageTypeCookieReply         = 3
	WireGuardMessageTypeTransportData       = 4
)
|
|
|
|
// --- End Types ---
|
|
|
|
// bufferPool allows reusing buffers to reduce allocations.
// Every buffer is 1500 bytes; callers that re-slice a buffer must return
// it at full length (e.g. buf[:1500]) so the next Get sees all capacity.
var bufferPool = sync.Pool{
	New: func() interface{} {
		return make([]byte, 1500)
	},
}
|
|
|
|
// UDPProxyServer has a channel for incoming packets.
// It relays WireGuard traffic between clients and destination peers and
// processes encrypted hole punch messages on the same UDP socket.
type UDPProxyServer struct {
	addr          string       // UDP listen address
	serverURL     string       // control server base URL
	conn          *net.UDPConn // main listening socket, set in Start
	proxyMappings sync.Map     // map[string]ProxyMapping where key is "ip:port"
	connections   sync.Map     // map[string]*DestinationConn, key is "destAddr-remoteAddr" (see getOrCreateConnection)
	privateKey    wgtypes.Key  // this server's WireGuard private key
	packetChan    chan Packet  // buffered queue feeding the worker pool

	// Session tracking for WireGuard peers
	// Key format: "senderIndex:receiverIndex"
	wgSessions sync.Map
	// Communication pattern tracking for rebuilding sessions
	// Key format: "clientIP:clientPort-destIP:destPort"
	commPatterns sync.Map
	// ReachableAt is the URL where this server can be reached
	ReachableAt string
}
|
|
|
|
// NewUDPProxyServer initializes the server with a buffered packet channel.
|
|
func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
|
|
return &UDPProxyServer{
|
|
addr: addr,
|
|
serverURL: serverURL,
|
|
privateKey: privateKey,
|
|
packetChan: make(chan Packet, 1000),
|
|
ReachableAt: reachableAt,
|
|
}
|
|
}
|
|
|
|
// Start sets up the UDP listener, worker pool, and begins reading packets.
|
|
func (s *UDPProxyServer) Start() error {
|
|
// Fetch initial mappings.
|
|
if err := s.fetchInitialMappings(); err != nil {
|
|
return fmt.Errorf("failed to fetch initial mappings: %v", err)
|
|
}
|
|
|
|
udpAddr, err := net.ResolveUDPAddr("udp", s.addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
conn, err := net.ListenUDP("udp", udpAddr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
s.conn = conn
|
|
logger.Info("UDP server listening on %s", s.addr)
|
|
|
|
// Start a fixed number of worker goroutines.
|
|
workerCount := 10 // TODO: Make this configurable or pick it better!
|
|
for i := 0; i < workerCount; i++ {
|
|
go s.packetWorker()
|
|
}
|
|
|
|
// Start the goroutine that reads packets from the UDP socket.
|
|
go s.readPackets()
|
|
|
|
// Start the idle connection cleanup routine.
|
|
go s.cleanupIdleConnections()
|
|
|
|
// Start the session cleanup routine
|
|
go s.cleanupIdleSessions()
|
|
|
|
// Start the proxy mapping cleanup routine
|
|
go s.cleanupIdleProxyMappings()
|
|
|
|
// Start the communication pattern cleanup routine
|
|
go s.cleanupIdleCommunicationPatterns()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *UDPProxyServer) Stop() {
|
|
s.conn.Close()
|
|
}
|
|
|
|
// readPackets continuously reads from the UDP socket and pushes packets into the channel.
|
|
func (s *UDPProxyServer) readPackets() {
|
|
for {
|
|
buf := bufferPool.Get().([]byte)
|
|
n, remoteAddr, err := s.conn.ReadFromUDP(buf)
|
|
if err != nil {
|
|
logger.Error("Error reading UDP packet: %v", err)
|
|
continue
|
|
}
|
|
s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
|
|
}
|
|
}
|
|
|
|
// packetWorker processes incoming packets from the channel.
|
|
func (s *UDPProxyServer) packetWorker() {
|
|
for packet := range s.packetChan {
|
|
// Determine packet type by inspecting the first byte.
|
|
if packet.n > 0 && packet.data[0] >= 1 && packet.data[0] <= 4 {
|
|
// Process as a WireGuard packet.
|
|
s.handleWireGuardPacket(packet.data, packet.remoteAddr)
|
|
} else {
|
|
// Process as an encrypted hole punch message
|
|
var encMsg EncryptedHolePunchMessage
|
|
if err := json.Unmarshal(packet.data, &encMsg); err != nil {
|
|
logger.Error("Error unmarshaling encrypted message: %v", err)
|
|
// Return the buffer to the pool for reuse and continue with next packet
|
|
bufferPool.Put(packet.data[:1500])
|
|
continue
|
|
}
|
|
|
|
if encMsg.EphemeralPublicKey == "" {
|
|
logger.Error("Received malformed message without ephemeral key")
|
|
// Return the buffer to the pool for reuse and continue with next packet
|
|
bufferPool.Put(packet.data[:1500])
|
|
continue
|
|
}
|
|
|
|
// This appears to be an encrypted message
|
|
decryptedData, err := s.decryptMessage(encMsg)
|
|
if err != nil {
|
|
logger.Error("Failed to decrypt message: %v", err)
|
|
// Return the buffer to the pool for reuse and continue with next packet
|
|
bufferPool.Put(packet.data[:1500])
|
|
continue
|
|
}
|
|
|
|
// Process the decrypted hole punch message
|
|
var msg HolePunchMessage
|
|
if err := json.Unmarshal(decryptedData, &msg); err != nil {
|
|
logger.Error("Error unmarshaling decrypted message: %v", err)
|
|
// Return the buffer to the pool for reuse and continue with next packet
|
|
bufferPool.Put(packet.data[:1500])
|
|
continue
|
|
}
|
|
|
|
endpoint := ClientEndpoint{
|
|
NewtID: msg.NewtID,
|
|
OlmID: msg.OlmID,
|
|
Token: msg.Token,
|
|
IP: packet.remoteAddr.IP.String(),
|
|
Port: packet.remoteAddr.Port,
|
|
Timestamp: time.Now().Unix(),
|
|
ReachableAt: s.ReachableAt,
|
|
PublicKey: s.privateKey.PublicKey().String(),
|
|
}
|
|
logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", packet.remoteAddr.String(), endpoint.IP, endpoint.Port)
|
|
s.notifyServer(endpoint)
|
|
s.clearSessionsForIP(endpoint.IP) // Clear sessions for this IP to allow re-establishment
|
|
}
|
|
// Return the buffer to the pool for reuse.
|
|
bufferPool.Put(packet.data[:1500])
|
|
}
|
|
}
|
|
|
|
// decryptMessage decrypts the message using the server's private key.
// The sender encrypts with ChaCha20-Poly1305 under an X25519 shared
// secret derived from its ephemeral key and this server's static key;
// the same secret is recomputed here to open the ciphertext.
// Returns the plaintext (expected to be a JSON-encoded HolePunchMessage)
// or an error if the key, nonce, or ciphertext is invalid.
func (s *UDPProxyServer) decryptMessage(encMsg EncryptedHolePunchMessage) ([]byte, error) {
	// Parse the ephemeral public key
	ephPubKey, err := wgtypes.ParseKey(encMsg.EphemeralPublicKey)
	if err != nil {
		return nil, fmt.Errorf("failed to parse ephemeral public key: %v", err)
	}

	// Use X25519 for key exchange instead of ScalarMult
	sharedSecret, err := curve25519.X25519(s.privateKey[:], ephPubKey[:])
	if err != nil {
		return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
	}

	// Create the AEAD cipher using the shared secret
	aead, err := chacha20poly1305.New(sharedSecret)
	if err != nil {
		return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
	}

	// Verify nonce size
	if len(encMsg.Nonce) != aead.NonceSize() {
		return nil, fmt.Errorf("invalid nonce size")
	}

	// Decrypt the ciphertext; Open also authenticates it.
	plaintext, err := aead.Open(nil, encMsg.Nonce, encMsg.Ciphertext, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt message: %v", err)
	}

	return plaintext, nil
}
|
|
|
|
func (s *UDPProxyServer) fetchInitialMappings() error {
|
|
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, s.privateKey.PublicKey().String())))
|
|
resp, err := http.Post(s.serverURL+"/gerbil/get-all-relays", "application/json", body)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to fetch mappings: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("server returned non-OK status: %d, body: %s",
|
|
resp.StatusCode, string(body))
|
|
}
|
|
data, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read response body: %v", err)
|
|
}
|
|
logger.Info("Received initial mappings: %s", string(data))
|
|
var initialMappings InitialMappings
|
|
if err := json.Unmarshal(data, &initialMappings); err != nil {
|
|
return fmt.Errorf("failed to unmarshal initial mappings: %v", err)
|
|
}
|
|
// Store mappings in our sync.Map.
|
|
for key, mapping := range initialMappings.Mappings {
|
|
// Initialize LastUsed timestamp for initial mappings
|
|
mapping.LastUsed = time.Now()
|
|
s.proxyMappings.Store(key, mapping)
|
|
}
|
|
logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings))
|
|
return nil
|
|
}
|
|
|
|
// Extract WireGuard message indices
|
|
func extractWireGuardIndices(packet []byte) (uint32, uint32, bool) {
|
|
if len(packet) < 12 {
|
|
return 0, 0, false
|
|
}
|
|
|
|
messageType := packet[0]
|
|
if messageType == WireGuardMessageTypeHandshakeInitiation {
|
|
// Handshake initiation: extract sender index at offset 4
|
|
senderIndex := binary.LittleEndian.Uint32(packet[4:8])
|
|
return 0, senderIndex, true
|
|
} else if messageType == WireGuardMessageTypeHandshakeResponse {
|
|
// Handshake response: extract sender index at offset 4 and receiver index at offset 8
|
|
senderIndex := binary.LittleEndian.Uint32(packet[4:8])
|
|
receiverIndex := binary.LittleEndian.Uint32(packet[8:12])
|
|
return receiverIndex, senderIndex, true
|
|
} else if messageType == WireGuardMessageTypeTransportData {
|
|
// Transport data: extract receiver index at offset 4
|
|
receiverIndex := binary.LittleEndian.Uint32(packet[4:8])
|
|
return receiverIndex, 0, true
|
|
}
|
|
|
|
return 0, 0, false
|
|
}
|
|
|
|
// Updated to handle multi-peer WireGuard communication.
// handleWireGuardPacket relays one WireGuard packet received from remoteAddr:
//   - handshake initiation: broadcast to every destination in the sender's
//     proxy mapping (the target peer is not yet known)
//   - handshake response: record a session keyed "receiverIndex:senderIndex"
//     with remoteAddr as the forward target, then forward to the mapping's
//     destinations
//   - transport data: forward to the single peer whose session matches the
//     packet's receiver index, falling back to a broadcast when none matches
//   - anything else (e.g. cookie reply): broadcast to all destinations
//
// Packets from addresses with no proxy mapping are dropped.
func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
	if len(packet) == 0 {
		logger.Error("Received empty packet")
		return
	}

	messageType := packet[0]
	receiverIndex, senderIndex, ok := extractWireGuardIndices(packet)

	if !ok {
		logger.Error("Failed to extract WireGuard indices")
		return
	}

	// Mappings are keyed by the sender's observed "ip:port".
	key := remoteAddr.String()
	mappingObj, ok := s.proxyMappings.Load(key)
	if !ok {
		logger.Error("No proxy mapping found for %s", key)
		return
	}

	proxyMapping := mappingObj.(ProxyMapping)
	// Update the last used timestamp and store it back.
	// NOTE(review): this load-modify-store is not atomic across workers;
	// concurrent packets from the same source can lose a LastUsed refresh
	// (harmless for cleanup purposes, but technically racy — confirm).
	proxyMapping.LastUsed = time.Now()
	s.proxyMappings.Store(key, proxyMapping)

	// Handle different WireGuard message types
	switch messageType {
	case WireGuardMessageTypeHandshakeInitiation:
		// Initial handshake: forward to all peers
		logger.Debug("Forwarding handshake initiation from %s (sender index: %d) to peers %v", remoteAddr, senderIndex, proxyMapping.Destinations)

		for _, dest := range proxyMapping.Destinations {
			destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
			if err != nil {
				logger.Error("Failed to resolve destination address: %v", err)
				continue
			}

			conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
			if err != nil {
				logger.Error("Failed to get/create connection: %v", err)
				continue
			}

			_, err = conn.Write(packet)
			if err != nil {
				logger.Error("Failed to forward handshake initiation: %v", err)
			}
		}

	case WireGuardMessageTypeHandshakeResponse:
		// Received handshake response: establish session mapping
		logger.Debug("Received handshake response with receiver index %d and sender index %d from %s",
			receiverIndex, senderIndex, remoteAddr)

		// Create a session key for the peer that sent the initial handshake
		sessionKey := fmt.Sprintf("%d:%d", receiverIndex, senderIndex)

		// Store the session information; remoteAddr (the responder) becomes
		// the forward target for transport data matching this index pair.
		s.wgSessions.Store(sessionKey, &WireGuardSession{
			ReceiverIndex: receiverIndex,
			SenderIndex:   senderIndex,
			DestAddr:      remoteAddr,
			LastSeen:      time.Now(),
		})

		// Forward the response to the original sender
		for _, dest := range proxyMapping.Destinations {
			destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
			if err != nil {
				logger.Error("Failed to resolve destination address: %v", err)
				continue
			}

			conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
			if err != nil {
				logger.Error("Failed to get/create connection: %v", err)
				continue
			}

			_, err = conn.Write(packet)
			if err != nil {
				logger.Error("Failed to forward handshake response: %v", err)
			}
		}

	case WireGuardMessageTypeTransportData:
		// Data packet: forward only to the established session peer
		// logger.Debug("Received transport data with receiver index %d from %s", receiverIndex, remoteAddr)

		// Look up the session based on the receiver index
		var destAddr *net.UDPAddr

		// First check for existing sessions to see if we know where to send this packet
		s.wgSessions.Range(func(k, v interface{}) bool {
			session := v.(*WireGuardSession)
			// Atomically check if session matches and update LastSeen if it does
			if addr, matches := session.CheckAndUpdateIfMatch(receiverIndex); matches {
				// Found matching session
				destAddr = addr
				return false // stop iteration
			}
			return true // continue iteration
		})

		if destAddr != nil {
			// We found a specific peer to forward to
			conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
			if err != nil {
				logger.Error("Failed to get/create connection: %v", err)
				return
			}

			// Track communication pattern for session rebuilding
			s.trackCommunicationPattern(remoteAddr, destAddr, receiverIndex, true)

			_, err = conn.Write(packet)
			if err != nil {
				// Debug (not Error): transient write failures on data packets
				// are expected during peer churn.
				logger.Debug("Failed to forward transport data: %v", err)
			}
		} else {
			// No known session, fall back to forwarding to all peers
			logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex)
			for _, dest := range proxyMapping.Destinations {
				destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
				if err != nil {
					logger.Error("Failed to resolve destination address: %v", err)
					continue
				}

				conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
				if err != nil {
					logger.Error("Failed to get/create connection: %v", err)
					continue
				}

				// Track communication pattern for session rebuilding
				s.trackCommunicationPattern(remoteAddr, destAddr, receiverIndex, true)

				_, err = conn.Write(packet)
				if err != nil {
					logger.Debug("Failed to forward transport data: %v", err)
				}
			}
		}

	default:
		// Other packet types (like cookie reply)
		logger.Debug("Forwarding WireGuard packet type %d from %s", messageType, remoteAddr)

		// Forward to all peers
		for _, dest := range proxyMapping.Destinations {
			destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
			if err != nil {
				logger.Error("Failed to resolve destination address: %v", err)
				continue
			}

			conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
			if err != nil {
				logger.Error("Failed to get/create connection: %v", err)
				continue
			}

			_, err = conn.Write(packet)
			if err != nil {
				logger.Error("Failed to forward WireGuard packet: %v", err)
			}
		}
	}
}
|
|
|
|
func (s *UDPProxyServer) getOrCreateConnection(destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) (*net.UDPConn, error) {
|
|
key := destAddr.String() + "-" + remoteAddr.String()
|
|
|
|
// Check if we have an existing connection
|
|
if conn, ok := s.connections.Load(key); ok {
|
|
destConn := conn.(*DestinationConn)
|
|
destConn.lastUsed = time.Now()
|
|
return destConn.conn, nil
|
|
}
|
|
|
|
// Create new connection
|
|
newConn, err := net.DialUDP("udp", nil, destAddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create UDP connection: %v", err)
|
|
}
|
|
|
|
// Store the new connection
|
|
s.connections.Store(key, &DestinationConn{
|
|
conn: newConn,
|
|
lastUsed: time.Now(),
|
|
})
|
|
|
|
// Start a goroutine to handle responses
|
|
go s.handleResponses(newConn, destAddr, remoteAddr)
|
|
|
|
return newConn, nil
|
|
}
|
|
|
|
// handleResponses pumps packets arriving on a destination connection back
// to the client through the main listener. Along the way it inspects
// WireGuard headers: a handshake response creates a session entry (keyed
// "senderIndex:receiverIndex") pointing at destAddr, and transport data
// feeds the communication-pattern tracker in the destination -> client
// direction. The goroutine exits when the connection errors, e.g. after
// cleanupIdleConnections closes it.
func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) {
	buffer := make([]byte, 1500)
	for {
		n, err := conn.Read(buffer)
		if err != nil {
			// Debug level: an error here normally just means the cached
			// connection was closed.
			logger.Debug("Error reading response from %s: %v", destAddr.String(), err)
			return
		}

		// Process the response to track sessions if it's a WireGuard packet
		if n > 0 && buffer[0] >= 1 && buffer[0] <= 4 {
			receiverIndex, senderIndex, ok := extractWireGuardIndices(buffer[:n])
			if ok && buffer[0] == WireGuardMessageTypeHandshakeResponse {
				// Store the session mapping for the handshake response
				sessionKey := fmt.Sprintf("%d:%d", senderIndex, receiverIndex)
				s.wgSessions.Store(sessionKey, &WireGuardSession{
					ReceiverIndex: receiverIndex,
					SenderIndex:   senderIndex,
					DestAddr:      destAddr,
					LastSeen:      time.Now(),
				})
				logger.Debug("Stored session mapping: %s -> %s", sessionKey, destAddr.String())
			} else if ok && buffer[0] == WireGuardMessageTypeTransportData {
				// Track communication pattern for session rebuilding (reverse direction)
				s.trackCommunicationPattern(destAddr, remoteAddr, receiverIndex, false)
			}
		}

		// Forward the response back through the main listener
		_, err = s.conn.WriteToUDP(buffer[:n], remoteAddr)
		if err != nil {
			logger.Error("Failed to forward response: %v", err)
		}
	}
}
|
|
|
|
// Add a cleanup method to periodically remove idle connections.
// Every 5 minutes, connections unused for more than 10 minutes are closed
// and evicted; closing also terminates the connection's handleResponses
// goroutine.
// NOTE(review): lastUsed is read here without synchronization while
// getOrCreateConnection writes it — technically a data race; confirm.
func (s *UDPProxyServer) cleanupIdleConnections() {
	ticker := time.NewTicker(5 * time.Minute)
	for range ticker.C {
		now := time.Now()
		s.connections.Range(func(key, value interface{}) bool {
			destConn := value.(*DestinationConn)
			if now.Sub(destConn.lastUsed) > 10*time.Minute {
				destConn.conn.Close()
				s.connections.Delete(key)
			}
			return true
		})
	}
}
|
|
|
|
// New method to periodically remove idle sessions
|
|
func (s *UDPProxyServer) cleanupIdleSessions() {
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
for range ticker.C {
|
|
now := time.Now()
|
|
s.wgSessions.Range(func(key, value interface{}) bool {
|
|
session := value.(*WireGuardSession)
|
|
// Use thread-safe method to read LastSeen
|
|
if now.Sub(session.GetLastSeen()) > 15*time.Minute {
|
|
s.wgSessions.Delete(key)
|
|
logger.Debug("Removed idle session: %s", key)
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}
|
|
|
|
// New method to periodically remove idle proxy mappings
|
|
func (s *UDPProxyServer) cleanupIdleProxyMappings() {
|
|
ticker := time.NewTicker(10 * time.Minute)
|
|
for range ticker.C {
|
|
now := time.Now()
|
|
s.proxyMappings.Range(func(key, value interface{}) bool {
|
|
mapping := value.(ProxyMapping)
|
|
// Remove mappings that haven't been used in 30 minutes
|
|
if now.Sub(mapping.LastUsed) > 30*time.Minute {
|
|
s.proxyMappings.Delete(key)
|
|
logger.Debug("Removed idle proxy mapping: %s", key)
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}
|
|
|
|
// notifyServer POSTs the observed client endpoint to the control server's
// update-hole-punch route and stores the ProxyMapping the server returns,
// keyed by the client's "ip:port". Errors are logged rather than returned
// — hole punch handling is best-effort.
func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) {
	logger.Debug("notifyServer called with endpoint: IP=%s, Port=%d", endpoint.IP, endpoint.Port)

	jsonData, err := json.Marshal(endpoint)
	if err != nil {
		logger.Error("Failed to marshal endpoint data: %v", err)
		return
	}

	resp, err := http.Post(s.serverURL+"/gerbil/update-hole-punch", "application/json", bytes.NewBuffer(jsonData))
	if err != nil {
		logger.Error("Failed to notify server: %v", err)
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		logger.Error("Server returned non-OK status: %d, body: %s",
			resp.StatusCode, string(body))
		return
	}

	// Parse the proxy mapping response
	var mapping ProxyMapping
	if err := json.NewDecoder(resp.Body).Decode(&mapping); err != nil {
		logger.Error("Failed to decode proxy mapping: %v", err)
		return
	}

	logger.Debug("Received proxy mapping from server: %v", mapping)

	// Store the mapping with current timestamp
	key := fmt.Sprintf("%s:%d", endpoint.IP, endpoint.Port)
	logger.Debug("About to store proxy mapping with key: %s (from endpoint IP=%s, Port=%d)", key, endpoint.IP, endpoint.Port)
	mapping.LastUsed = time.Now()
	s.proxyMappings.Store(key, mapping)

	logger.Debug("Stored proxy mapping for %s with %d destinations (timestamp: %v)", key, len(mapping.Destinations), mapping.LastUsed)
}
|
|
|
|
// Updated to support multiple destinations
|
|
func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, destinations []PeerDestination) {
|
|
key := fmt.Sprintf("%s:%d", sourceIP, sourcePort)
|
|
mapping := ProxyMapping{
|
|
Destinations: destinations,
|
|
LastUsed: time.Now(),
|
|
}
|
|
s.proxyMappings.Store(key, mapping)
|
|
}
|
|
|
|
// OnPeerAdded clears connections and sessions for a specific WireGuard IP to allow re-establishment
|
|
func (s *UDPProxyServer) OnPeerAdded(wgIP string) {
|
|
logger.Info("Clearing connections for added peer with WG IP: %s", wgIP)
|
|
s.clearConnectionsForWGIP(wgIP)
|
|
// s.clearSessionsForWGIP(wgIP) THE DEST ADDR IS NOT THE WG IP, SO THIS IS NOT NEEDED
|
|
// s.clearProxyMappingsForWGIP(wgIP)
|
|
}
|
|
|
|
// OnPeerRemoved clears connections and sessions for a specific WireGuard IP
|
|
func (s *UDPProxyServer) OnPeerRemoved(wgIP string) {
|
|
logger.Info("Clearing connections for removed peer with WG IP: %s", wgIP)
|
|
s.clearConnectionsForWGIP(wgIP)
|
|
// s.clearSessionsForWGIP(wgIP) THE DEST ADDR IS NOT THE WG IP, SO THIS IS NOT NEEDED
|
|
// s.clearProxyMappingsForWGIP(wgIP)
|
|
}
|
|
|
|
// clearConnectionsForWGIP removes all connections associated with a specific WireGuard IP
|
|
func (s *UDPProxyServer) clearConnectionsForWGIP(wgIP string) {
|
|
var keysToDelete []string
|
|
|
|
s.connections.Range(func(key, value interface{}) bool {
|
|
keyStr := key.(string)
|
|
destConn := value.(*DestinationConn)
|
|
|
|
// Connection keys are in format "destAddr-remoteAddr"
|
|
// Check if either destination or remote address contains the WG IP
|
|
if containsIP(keyStr, wgIP) {
|
|
keysToDelete = append(keysToDelete, keyStr)
|
|
destConn.conn.Close()
|
|
logger.Debug("Closing connection for WG IP %s: %s", wgIP, keyStr)
|
|
}
|
|
return true
|
|
})
|
|
|
|
// Delete the connections
|
|
for _, key := range keysToDelete {
|
|
s.connections.Delete(key)
|
|
}
|
|
|
|
logger.Info("Cleared %d connections for WG IP: %s", len(keysToDelete), wgIP)
|
|
}
|
|
|
|
// clearSessionsForWGIP removes all WireGuard sessions associated with a specific WireGuard IP
|
|
func (s *UDPProxyServer) clearSessionsForIP(ip string) {
|
|
var keysToDelete []string
|
|
|
|
s.wgSessions.Range(func(key, value interface{}) bool {
|
|
keyStr := key.(string)
|
|
session := value.(*WireGuardSession)
|
|
|
|
// Check if the session's destination address contains the WG IP (thread-safe)
|
|
destAddr := session.GetDestAddr()
|
|
if destAddr != nil && destAddr.IP.String() == ip {
|
|
keysToDelete = append(keysToDelete, keyStr)
|
|
logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr)
|
|
}
|
|
return true
|
|
})
|
|
|
|
// Delete the sessions
|
|
for _, key := range keysToDelete {
|
|
s.wgSessions.Delete(key)
|
|
}
|
|
|
|
logger.Info("Cleared %d sessions for WG IP: %s", len(keysToDelete), ip)
|
|
}
|
|
|
|
// // clearProxyMappingsForWGIP removes all proxy mappings that have destinations pointing to a specific WireGuard IP
|
|
// func (s *UDPProxyServer) clearProxyMappingsForWGIP(wgIP string) {
|
|
// var keysToDelete []string
|
|
|
|
// s.proxyMappings.Range(func(key, value interface{}) bool {
|
|
// keyStr := key.(string)
|
|
// mapping := value.(ProxyMapping)
|
|
|
|
// // Check if any destination in the mapping contains the WG IP
|
|
// for _, dest := range mapping.Destinations {
|
|
// if dest.DestinationIP == wgIP {
|
|
// keysToDelete = append(keysToDelete, keyStr)
|
|
// logger.Debug("Marking proxy mapping for deletion for WG IP %s: %s -> %s:%d", wgIP, keyStr, dest.DestinationIP, dest.DestinationPort)
|
|
// break // Found one destination, no need to check others in this mapping
|
|
// }
|
|
// }
|
|
// return true
|
|
// })
|
|
|
|
// // Delete the proxy mappings
|
|
// for _, key := range keysToDelete {
|
|
// s.proxyMappings.Delete(key)
|
|
// logger.Debug("Deleted proxy mapping: %s", key)
|
|
// }
|
|
|
|
// logger.Info("Cleared %d proxy mappings for WG IP: %s", len(keysToDelete), wgIP)
|
|
// }
|
|
|
|
// containsIP reports whether a connection key references the given IP as
// either endpoint. Connection keys have the form
// "destIP:destPort-remoteIP:remotePort"; the IP must appear immediately
// followed by ':' at the start of either half, so "1.2.3.4" matches
// neither "11.2.3.4:..." nor a port substring.
// Replaces the hand-rolled prefix comparison and dash scan with
// strings.HasPrefix and strings.Cut — same results, far less code.
// NOTE(review): like the original, this does not handle bracketed IPv6
// forms such as "[::1]:port" — confirm whether that matters for callers.
func containsIP(connectionKey, ip string) bool {
	ipWithColon := ip + ":"

	// Destination half: the key starts with "ip:".
	if strings.HasPrefix(connectionKey, ipWithColon) {
		return true
	}

	// Remote half: the part after the first '-' starts with "ip:".
	if _, remote, found := strings.Cut(connectionKey, "-"); found {
		return strings.HasPrefix(remote, ipWithColon)
	}

	return false
}
|
|
|
|
// UpdateDestinationInMappings updates all proxy mappings that contain the old destination with the new destination
|
|
// Returns the number of mappings that were updated
|
|
func (s *UDPProxyServer) UpdateDestinationInMappings(oldDest, newDest PeerDestination) int {
|
|
updatedCount := 0
|
|
|
|
s.proxyMappings.Range(func(key, value interface{}) bool {
|
|
keyStr := key.(string)
|
|
mapping := value.(ProxyMapping)
|
|
updated := false
|
|
|
|
// Check each destination in the mapping
|
|
for i, dest := range mapping.Destinations {
|
|
if dest.DestinationIP == oldDest.DestinationIP && dest.DestinationPort == oldDest.DestinationPort {
|
|
// Update this destination
|
|
mapping.Destinations[i] = newDest
|
|
updated = true
|
|
logger.Debug("Updated destination in mapping %s: %s:%d -> %s:%d",
|
|
keyStr, oldDest.DestinationIP, oldDest.DestinationPort,
|
|
newDest.DestinationIP, newDest.DestinationPort)
|
|
}
|
|
}
|
|
|
|
// If we updated any destinations, store the updated mapping back
|
|
if updated {
|
|
mapping.LastUsed = time.Now()
|
|
s.proxyMappings.Store(keyStr, mapping)
|
|
updatedCount++
|
|
}
|
|
|
|
return true // continue iteration
|
|
})
|
|
|
|
if updatedCount > 0 {
|
|
logger.Info("Updated %d proxy mappings from %s:%d to %s:%d",
|
|
updatedCount, oldDest.DestinationIP, oldDest.DestinationPort,
|
|
newDest.DestinationIP, newDest.DestinationPort)
|
|
}
|
|
|
|
return updatedCount
|
|
}
|
|
|
|
// trackCommunicationPattern tracks bidirectional communication patterns to rebuild sessions
|
|
func (s *UDPProxyServer) trackCommunicationPattern(fromAddr, toAddr *net.UDPAddr, receiverIndex uint32, fromClient bool) {
|
|
var clientAddr, destAddr *net.UDPAddr
|
|
var clientIndex, destIndex uint32
|
|
|
|
if fromClient {
|
|
clientAddr = fromAddr
|
|
destAddr = toAddr
|
|
clientIndex = receiverIndex
|
|
destIndex = 0 // We don't know the destination index yet
|
|
} else {
|
|
clientAddr = toAddr
|
|
destAddr = fromAddr
|
|
clientIndex = 0 // We don't know the client index yet
|
|
destIndex = receiverIndex
|
|
}
|
|
|
|
patternKey := fmt.Sprintf("%s-%s", clientAddr.String(), destAddr.String())
|
|
now := time.Now()
|
|
|
|
if existingPattern, ok := s.commPatterns.Load(patternKey); ok {
|
|
pattern := existingPattern.(*CommunicationPattern)
|
|
|
|
// Update the pattern
|
|
if fromClient {
|
|
pattern.LastFromClient = now
|
|
if pattern.ClientIndex == 0 {
|
|
pattern.ClientIndex = clientIndex
|
|
}
|
|
} else {
|
|
pattern.LastFromDest = now
|
|
if pattern.DestIndex == 0 {
|
|
pattern.DestIndex = destIndex
|
|
}
|
|
}
|
|
|
|
pattern.PacketCount++
|
|
s.commPatterns.Store(patternKey, pattern)
|
|
|
|
// Check if we have bidirectional communication and can rebuild a session
|
|
s.tryRebuildSession(pattern)
|
|
} else {
|
|
// Create new pattern
|
|
pattern := &CommunicationPattern{
|
|
FromClient: clientAddr,
|
|
ToDestination: destAddr,
|
|
ClientIndex: clientIndex,
|
|
DestIndex: destIndex,
|
|
PacketCount: 1,
|
|
}
|
|
|
|
if fromClient {
|
|
pattern.LastFromClient = now
|
|
} else {
|
|
pattern.LastFromDest = now
|
|
}
|
|
|
|
s.commPatterns.Store(patternKey, pattern)
|
|
}
|
|
}
|
|
|
|
// tryRebuildSession attempts to rebuild a WireGuard session from communication patterns
|
|
func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
|
|
// Check if we have bidirectional communication within a reasonable time window
|
|
timeDiff := pattern.LastFromClient.Sub(pattern.LastFromDest)
|
|
if timeDiff < 0 {
|
|
timeDiff = -timeDiff
|
|
}
|
|
|
|
// Only rebuild if we have recent bidirectional communication and both indices
|
|
if timeDiff < 30*time.Second && pattern.ClientIndex != 0 && pattern.DestIndex != 0 && pattern.PacketCount >= 4 {
|
|
// Create session mapping: client's index maps to destination
|
|
sessionKey := fmt.Sprintf("%d:%d", pattern.DestIndex, pattern.ClientIndex)
|
|
|
|
// Check if we already have this session
|
|
if _, exists := s.wgSessions.Load(sessionKey); !exists {
|
|
session := &WireGuardSession{
|
|
ReceiverIndex: pattern.DestIndex,
|
|
SenderIndex: pattern.ClientIndex,
|
|
DestAddr: pattern.ToDestination,
|
|
LastSeen: time.Now(),
|
|
}
|
|
|
|
s.wgSessions.Store(sessionKey, session)
|
|
logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)",
|
|
sessionKey, pattern.ToDestination.String(), pattern.PacketCount)
|
|
}
|
|
}
|
|
}
|
|
|
|
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns.
// Every 10 minutes, patterns with no traffic in either direction for 20
// minutes are dropped.
// NOTE(review): pattern timestamps are read here without a lock while
// trackCommunicationPattern writes them — technically a data race; confirm.
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
	ticker := time.NewTicker(10 * time.Minute)
	for range ticker.C {
		now := time.Now()
		s.commPatterns.Range(func(key, value interface{}) bool {
			pattern := value.(*CommunicationPattern)

			// Get the most recent activity in either direction
			lastActivity := pattern.LastFromClient
			if pattern.LastFromDest.After(lastActivity) {
				lastActivity = pattern.LastFromDest
			}

			// Remove patterns that haven't had activity in 20 minutes
			if now.Sub(lastActivity) > 20*time.Minute {
				s.commPatterns.Delete(key)
				logger.Debug("Removed idle communication pattern: %s", key)
			}
			return true
		})
	}
}
|