Files
gerbil/relay/relay.go

622 lines
18 KiB
Go

package relay
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strconv"
"sync"
"time"
"github.com/fosrl/gerbil/logger"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type EncryptedHolePunchMessage struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}
type HolePunchMessage struct {
OlmID string `json:"olmId"`
NewtID string `json:"newtId"`
Token string `json:"token"`
}
type ClientEndpoint struct {
OlmID string `json:"olmId"`
NewtID string `json:"newtId"`
Token string `json:"token"`
IP string `json:"ip"`
Port int `json:"port"`
Timestamp int64 `json:"timestamp"`
}
// Updated to support multiple destination peers
type ProxyMapping struct {
Destinations []PeerDestination `json:"destinations"`
}
type PeerDestination struct {
DestinationIP string `json:"destinationIP"`
DestinationPort int `json:"destinationPort"`
}
type DestinationConn struct {
conn *net.UDPConn
lastUsed time.Time
}
// Type for storing WireGuard handshake information
type WireGuardSession struct {
ReceiverIndex uint32
SenderIndex uint32
DestAddr *net.UDPAddr
LastSeen time.Time
}
type InitialMappings struct {
Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port"
}
// Packet is a simple struct to hold the packet data and sender info.
type Packet struct {
data []byte
remoteAddr *net.UDPAddr
n int
}
// WireGuard message types
const (
WireGuardMessageTypeHandshakeInitiation = 1
WireGuardMessageTypeHandshakeResponse = 2
WireGuardMessageTypeCookieReply = 3
WireGuardMessageTypeTransportData = 4
)
// --- End Types ---
// bufferPool allows reusing buffers to reduce allocations.
var bufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, 1500)
},
}
// UDPProxyServer has a channel for incoming packets.
type UDPProxyServer struct {
addr string
serverURL string
conn *net.UDPConn
proxyMappings sync.Map // map[string]ProxyMapping where key is "ip:port"
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
privateKey wgtypes.Key
packetChan chan Packet
// Session tracking for WireGuard peers
// Key format: "senderIndex:receiverIndex"
wgSessions sync.Map
}
// NewUDPProxyServer initializes the server with a buffered packet channel.
func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key) *UDPProxyServer {
return &UDPProxyServer{
addr: addr,
serverURL: serverURL,
privateKey: privateKey,
packetChan: make(chan Packet, 1000),
}
}
// Start sets up the UDP listener, worker pool, and begins reading packets.
func (s *UDPProxyServer) Start() error {
// Fetch initial mappings.
if err := s.fetchInitialMappings(); err != nil {
return fmt.Errorf("failed to fetch initial mappings: %v", err)
}
udpAddr, err := net.ResolveUDPAddr("udp", s.addr)
if err != nil {
return err
}
conn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return err
}
s.conn = conn
logger.Info("UDP server listening on %s", s.addr)
// Start a fixed number of worker goroutines.
workerCount := 10 // TODO: Make this configurable or pick it better!
for i := 0; i < workerCount; i++ {
go s.packetWorker()
}
// Start the goroutine that reads packets from the UDP socket.
go s.readPackets()
// Start the idle connection cleanup routine.
go s.cleanupIdleConnections()
// Start the session cleanup routine
go s.cleanupIdleSessions()
return nil
}
func (s *UDPProxyServer) Stop() {
s.conn.Close()
}
// readPackets continuously reads from the UDP socket and pushes packets into the channel.
func (s *UDPProxyServer) readPackets() {
for {
buf := bufferPool.Get().([]byte)
n, remoteAddr, err := s.conn.ReadFromUDP(buf)
if err != nil {
logger.Error("Error reading UDP packet: %v", err)
continue
}
s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
}
}
// packetWorker processes incoming packets from the channel.
func (s *UDPProxyServer) packetWorker() {
for packet := range s.packetChan {
// Determine packet type by inspecting the first byte.
if packet.n > 0 && packet.data[0] >= 1 && packet.data[0] <= 4 {
// Process as a WireGuard packet.
s.handleWireGuardPacket(packet.data, packet.remoteAddr)
} else {
// Process as an encrypted hole punch message
var encMsg EncryptedHolePunchMessage
if err := json.Unmarshal(packet.data, &encMsg); err != nil {
logger.Error("Error unmarshaling encrypted message: %v", err)
// Return the buffer to the pool for reuse and continue with next packet
bufferPool.Put(packet.data[:1500])
continue
}
if encMsg.EphemeralPublicKey == "" {
logger.Error("Received malformed message without ephemeral key")
// Return the buffer to the pool for reuse and continue with next packet
bufferPool.Put(packet.data[:1500])
continue
}
// This appears to be an encrypted message
decryptedData, err := s.decryptMessage(encMsg)
if err != nil {
logger.Error("Failed to decrypt message: %v", err)
// Return the buffer to the pool for reuse and continue with next packet
bufferPool.Put(packet.data[:1500])
continue
}
// Process the decrypted hole punch message
var msg HolePunchMessage
if err := json.Unmarshal(decryptedData, &msg); err != nil {
logger.Error("Error unmarshaling decrypted message: %v", err)
// Return the buffer to the pool for reuse and continue with next packet
bufferPool.Put(packet.data[:1500])
continue
}
endpoint := ClientEndpoint{
NewtID: msg.NewtID,
OlmID: msg.OlmID,
Token: msg.Token,
IP: packet.remoteAddr.IP.String(),
Port: packet.remoteAddr.Port,
Timestamp: time.Now().Unix(),
}
s.notifyServer(endpoint)
}
// Return the buffer to the pool for reuse.
bufferPool.Put(packet.data[:1500])
}
}
// decryptMessage decrypts the message using the server's private key
func (s *UDPProxyServer) decryptMessage(encMsg EncryptedHolePunchMessage) ([]byte, error) {
// Parse the ephemeral public key
ephPubKey, err := wgtypes.ParseKey(encMsg.EphemeralPublicKey)
if err != nil {
return nil, fmt.Errorf("failed to parse ephemeral public key: %v", err)
}
// Use X25519 for key exchange instead of ScalarMult
sharedSecret, err := curve25519.X25519(s.privateKey[:], ephPubKey[:])
if err != nil {
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
}
// Create the AEAD cipher using the shared secret
aead, err := chacha20poly1305.New(sharedSecret)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
}
// Verify nonce size
if len(encMsg.Nonce) != aead.NonceSize() {
return nil, fmt.Errorf("invalid nonce size")
}
// Decrypt the ciphertext
plaintext, err := aead.Open(nil, encMsg.Nonce, encMsg.Ciphertext, nil)
if err != nil {
return nil, fmt.Errorf("failed to decrypt message: %v", err)
}
return plaintext, nil
}
func (s *UDPProxyServer) fetchInitialMappings() error {
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, s.privateKey.PublicKey().String())))
resp, err := http.Post(s.serverURL+"/gerbil/get-all-relays", "application/json", body)
if err != nil {
return fmt.Errorf("failed to fetch mappings: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("server returned non-OK status: %d, body: %s",
resp.StatusCode, string(body))
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response body: %v", err)
}
logger.Info("Received initial mappings: %s", string(data))
var initialMappings InitialMappings
if err := json.Unmarshal(data, &initialMappings); err != nil {
return fmt.Errorf("failed to unmarshal initial mappings: %v", err)
}
// Store mappings in our sync.Map.
for key, mapping := range initialMappings.Mappings {
s.proxyMappings.Store(key, mapping)
}
logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings))
return nil
}
// Extract WireGuard message indices
func extractWireGuardIndices(packet []byte) (uint32, uint32, bool) {
if len(packet) < 12 {
return 0, 0, false
}
messageType := packet[0]
if messageType == WireGuardMessageTypeHandshakeInitiation {
// Handshake initiation: extract sender index at offset 4
senderIndex := binary.LittleEndian.Uint32(packet[4:8])
return 0, senderIndex, true
} else if messageType == WireGuardMessageTypeHandshakeResponse {
// Handshake response: extract sender index at offset 4 and receiver index at offset 8
senderIndex := binary.LittleEndian.Uint32(packet[4:8])
receiverIndex := binary.LittleEndian.Uint32(packet[8:12])
return receiverIndex, senderIndex, true
} else if messageType == WireGuardMessageTypeTransportData {
// Transport data: extract receiver index at offset 4
receiverIndex := binary.LittleEndian.Uint32(packet[4:8])
return receiverIndex, 0, true
}
return 0, 0, false
}
// Updated to handle multi-peer WireGuard communication
func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
if len(packet) == 0 {
logger.Error("Received empty packet")
return
}
messageType := packet[0]
receiverIndex, senderIndex, ok := extractWireGuardIndices(packet)
if !ok {
logger.Error("Failed to extract WireGuard indices")
return
}
key := remoteAddr.String()
mappingObj, ok := s.proxyMappings.Load(key)
if !ok {
logger.Error("No proxy mapping found for %s", key)
return
}
proxyMapping := mappingObj.(ProxyMapping)
// Handle different WireGuard message types
switch messageType {
case WireGuardMessageTypeHandshakeInitiation:
// Initial handshake: forward to all peers
logger.Debug("Forwarding handshake initiation from %s (sender index: %d)", remoteAddr, senderIndex)
for _, dest := range proxyMapping.Destinations {
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
if err != nil {
logger.Error("Failed to resolve destination address: %v", err)
continue
}
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
if err != nil {
logger.Error("Failed to get/create connection: %v", err)
continue
}
_, err = conn.Write(packet)
if err != nil {
logger.Error("Failed to forward handshake initiation: %v", err)
}
}
case WireGuardMessageTypeHandshakeResponse:
// Received handshake response: establish session mapping
logger.Debug("Received handshake response with receiver index %d and sender index %d from %s",
receiverIndex, senderIndex, remoteAddr)
// Create a session key for the peer that sent the initial handshake
sessionKey := fmt.Sprintf("%d:%d", receiverIndex, senderIndex)
// Store the session information
s.wgSessions.Store(sessionKey, &WireGuardSession{
ReceiverIndex: receiverIndex,
SenderIndex: senderIndex,
DestAddr: remoteAddr,
LastSeen: time.Now(),
})
// Forward the response to the original sender
for _, dest := range proxyMapping.Destinations {
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
if err != nil {
logger.Error("Failed to resolve destination address: %v", err)
continue
}
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
if err != nil {
logger.Error("Failed to get/create connection: %v", err)
continue
}
_, err = conn.Write(packet)
if err != nil {
logger.Error("Failed to forward handshake response: %v", err)
}
}
case WireGuardMessageTypeTransportData:
// Data packet: forward only to the established session peer
logger.Debug("Received transport data with receiver index %d from %s", receiverIndex, remoteAddr)
// Look up the session based on the receiver index
var destAddr *net.UDPAddr
// First check for existing sessions to see if we know where to send this packet
s.wgSessions.Range(func(k, v interface{}) bool {
session := v.(*WireGuardSession)
if session.SenderIndex == receiverIndex {
// Found matching session
destAddr = session.DestAddr
// Update last seen time
session.LastSeen = time.Now()
s.wgSessions.Store(k, session)
return false // stop iteration
}
return true // continue iteration
})
if destAddr != nil {
// We found a specific peer to forward to
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
if err != nil {
logger.Error("Failed to get/create connection: %v", err)
return
}
_, err = conn.Write(packet)
if err != nil {
logger.Error("Failed to forward transport data: %v", err)
}
} else {
// No known session, fall back to forwarding to all peers
logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex)
for _, dest := range proxyMapping.Destinations {
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
if err != nil {
logger.Error("Failed to resolve destination address: %v", err)
continue
}
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
if err != nil {
logger.Error("Failed to get/create connection: %v", err)
continue
}
_, err = conn.Write(packet)
if err != nil {
logger.Error("Failed to forward transport data: %v", err)
}
}
}
default:
// Other packet types (like cookie reply)
logger.Debug("Forwarding WireGuard packet type %d from %s", messageType, remoteAddr)
// Forward to all peers
for _, dest := range proxyMapping.Destinations {
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
if err != nil {
logger.Error("Failed to resolve destination address: %v", err)
continue
}
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
if err != nil {
logger.Error("Failed to get/create connection: %v", err)
continue
}
_, err = conn.Write(packet)
if err != nil {
logger.Error("Failed to forward WireGuard packet: %v", err)
}
}
}
}
func (s *UDPProxyServer) getOrCreateConnection(destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) (*net.UDPConn, error) {
key := destAddr.String() + "-" + remoteAddr.String()
// Check if we have an existing connection
if conn, ok := s.connections.Load(key); ok {
destConn := conn.(*DestinationConn)
destConn.lastUsed = time.Now()
return destConn.conn, nil
}
// Create new connection
newConn, err := net.DialUDP("udp", nil, destAddr)
if err != nil {
return nil, fmt.Errorf("failed to create UDP connection: %v", err)
}
// Store the new connection
s.connections.Store(key, &DestinationConn{
conn: newConn,
lastUsed: time.Now(),
})
// Start a goroutine to handle responses
go s.handleResponses(newConn, destAddr, remoteAddr)
return newConn, nil
}
func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) {
buffer := make([]byte, 1500)
for {
n, err := conn.Read(buffer)
if err != nil {
logger.Error("Error reading response from %s: %v", destAddr.String(), err)
return
}
// Process the response to track sessions if it's a WireGuard packet
if n > 0 && buffer[0] >= 1 && buffer[0] <= 4 {
receiverIndex, senderIndex, ok := extractWireGuardIndices(buffer[:n])
if ok && buffer[0] == WireGuardMessageTypeHandshakeResponse {
// Store the session mapping for the handshake response
sessionKey := fmt.Sprintf("%d:%d", senderIndex, receiverIndex)
s.wgSessions.Store(sessionKey, &WireGuardSession{
ReceiverIndex: receiverIndex,
SenderIndex: senderIndex,
DestAddr: destAddr,
LastSeen: time.Now(),
})
logger.Debug("Stored session mapping: %s -> %s", sessionKey, destAddr.String())
}
}
// Forward the response back through the main listener
_, err = s.conn.WriteToUDP(buffer[:n], remoteAddr)
if err != nil {
logger.Error("Failed to forward response: %v", err)
}
}
}
// Add a cleanup method to periodically remove idle connections
func (s *UDPProxyServer) cleanupIdleConnections() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
now := time.Now()
s.connections.Range(func(key, value interface{}) bool {
destConn := value.(*DestinationConn)
if now.Sub(destConn.lastUsed) > 10*time.Minute {
destConn.conn.Close()
s.connections.Delete(key)
}
return true
})
}
}
// New method to periodically remove idle sessions
func (s *UDPProxyServer) cleanupIdleSessions() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
now := time.Now()
s.wgSessions.Range(func(key, value interface{}) bool {
session := value.(*WireGuardSession)
if now.Sub(session.LastSeen) > 15*time.Minute {
s.wgSessions.Delete(key)
logger.Debug("Removed idle session: %s", key)
}
return true
})
}
}
func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) {
jsonData, err := json.Marshal(endpoint)
if err != nil {
logger.Error("Failed to marshal endpoint data: %v", err)
return
}
resp, err := http.Post(s.serverURL+"/gerbil/update-hole-punch", "application/json", bytes.NewBuffer(jsonData))
if err != nil {
logger.Error("Failed to notify server: %v", err)
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
logger.Error("Server returned non-OK status: %d, body: %s",
resp.StatusCode, string(body))
return
}
// Parse the proxy mapping response
var mapping ProxyMapping
if err := json.NewDecoder(resp.Body).Decode(&mapping); err != nil {
logger.Error("Failed to decode proxy mapping: %v", err)
return
}
logger.Debug("Received proxy mapping: %v", mapping)
// Store the mapping
key := fmt.Sprintf("%s:%d", endpoint.IP, endpoint.Port)
s.proxyMappings.Store(key, mapping)
logger.Debug("Stored proxy mapping for %s with %d destinations", key, len(mapping.Destinations))
}
// Updated to support multiple destinations
func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, destinations []PeerDestination) {
key := net.JoinHostPort(sourceIP, strconv.Itoa(sourcePort))
mapping := ProxyMapping{
Destinations: destinations,
}
s.proxyMappings.Store(key, mapping)
}