mirror of
https://github.com/fosrl/gerbil.git
synced 2026-02-22 12:56:40 +00:00
Attempt to add sender and receiver ids to relaying
This commit is contained in:
273
relay/relay.go
273
relay/relay.go
@@ -2,6 +2,7 @@ package relay
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -38,7 +39,12 @@ type ClientEndpoint struct {
|
|||||||
Timestamp int64 `json:"timestamp"`
|
Timestamp int64 `json:"timestamp"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Updated to support multiple destination peers
|
||||||
type ProxyMapping struct {
|
type ProxyMapping struct {
|
||||||
|
Destinations []PeerDestination `json:"destinations"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type PeerDestination struct {
|
||||||
DestinationIP string `json:"destinationIP"`
|
DestinationIP string `json:"destinationIP"`
|
||||||
DestinationPort int `json:"destinationPort"`
|
DestinationPort int `json:"destinationPort"`
|
||||||
}
|
}
|
||||||
@@ -48,6 +54,14 @@ type DestinationConn struct {
|
|||||||
lastUsed time.Time
|
lastUsed time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Type for storing WireGuard handshake information
|
||||||
|
type WireGuardSession struct {
|
||||||
|
ReceiverIndex uint32
|
||||||
|
SenderIndex uint32
|
||||||
|
DestAddr *net.UDPAddr
|
||||||
|
LastSeen time.Time
|
||||||
|
}
|
||||||
|
|
||||||
type InitialMappings struct {
|
type InitialMappings struct {
|
||||||
Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port"
|
Mappings map[string]ProxyMapping `json:"mappings"` // key is "ip:port"
|
||||||
}
|
}
|
||||||
@@ -59,6 +73,14 @@ type Packet struct {
|
|||||||
n int
|
n int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WireGuard message types
|
||||||
|
const (
|
||||||
|
WireGuardMessageTypeHandshakeInitiation = 1
|
||||||
|
WireGuardMessageTypeHandshakeResponse = 2
|
||||||
|
WireGuardMessageTypeCookieReply = 3
|
||||||
|
WireGuardMessageTypeTransportData = 4
|
||||||
|
)
|
||||||
|
|
||||||
// --- End Types ---
|
// --- End Types ---
|
||||||
|
|
||||||
// bufferPool allows reusing buffers to reduce allocations.
|
// bufferPool allows reusing buffers to reduce allocations.
|
||||||
@@ -77,6 +99,10 @@ type UDPProxyServer struct {
|
|||||||
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
|
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
|
||||||
privateKey wgtypes.Key
|
privateKey wgtypes.Key
|
||||||
packetChan chan Packet
|
packetChan chan Packet
|
||||||
|
|
||||||
|
// Session tracking for WireGuard peers
|
||||||
|
// Key format: "senderIndex:receiverIndex"
|
||||||
|
wgSessions sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUDPProxyServer initializes the server with a buffered packet channel.
|
// NewUDPProxyServer initializes the server with a buffered packet channel.
|
||||||
@@ -119,6 +145,9 @@ func (s *UDPProxyServer) Start() error {
|
|||||||
// Start the idle connection cleanup routine.
|
// Start the idle connection cleanup routine.
|
||||||
go s.cleanupIdleConnections()
|
go s.cleanupIdleConnections()
|
||||||
|
|
||||||
|
// Start the session cleanup routine
|
||||||
|
go s.cleanupIdleSessions()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -259,34 +288,201 @@ func (s *UDPProxyServer) fetchInitialMappings() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Example handleWireGuardPacket remains unchanged.
|
// Extract WireGuard message indices
|
||||||
|
func extractWireGuardIndices(packet []byte) (uint32, uint32, bool) {
|
||||||
|
if len(packet) < 12 {
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
messageType := packet[0]
|
||||||
|
if messageType == WireGuardMessageTypeHandshakeInitiation {
|
||||||
|
// Handshake initiation: extract sender index at offset 4
|
||||||
|
senderIndex := binary.LittleEndian.Uint32(packet[4:8])
|
||||||
|
return 0, senderIndex, true
|
||||||
|
} else if messageType == WireGuardMessageTypeHandshakeResponse {
|
||||||
|
// Handshake response: extract sender index at offset 4 and receiver index at offset 8
|
||||||
|
senderIndex := binary.LittleEndian.Uint32(packet[4:8])
|
||||||
|
receiverIndex := binary.LittleEndian.Uint32(packet[8:12])
|
||||||
|
return receiverIndex, senderIndex, true
|
||||||
|
} else if messageType == WireGuardMessageTypeTransportData {
|
||||||
|
// Transport data: extract receiver index at offset 4
|
||||||
|
receiverIndex := binary.LittleEndian.Uint32(packet[4:8])
|
||||||
|
return receiverIndex, 0, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Updated to handle multi-peer WireGuard communication
|
||||||
func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
|
func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) {
|
||||||
|
if len(packet) == 0 {
|
||||||
|
logger.Error("Received empty packet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
messageType := packet[0]
|
||||||
|
receiverIndex, senderIndex, ok := extractWireGuardIndices(packet)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
logger.Error("Failed to extract WireGuard indices")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
key := remoteAddr.String()
|
key := remoteAddr.String()
|
||||||
mapping, ok := s.proxyMappings.Load(key)
|
mappingObj, ok := s.proxyMappings.Load(key)
|
||||||
if !ok {
|
if !ok {
|
||||||
logger.Error("No proxy mapping found for %s", key)
|
logger.Error("No proxy mapping found for %s", key)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
proxyMapping := mapping.(ProxyMapping)
|
|
||||||
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d",
|
proxyMapping := mappingObj.(ProxyMapping)
|
||||||
proxyMapping.DestinationIP, proxyMapping.DestinationPort))
|
|
||||||
if err != nil {
|
// Handle different WireGuard message types
|
||||||
logger.Error("Failed to resolve destination address: %v", err)
|
switch messageType {
|
||||||
return
|
case WireGuardMessageTypeHandshakeInitiation:
|
||||||
}
|
// Initial handshake: forward to all peers
|
||||||
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
logger.Debug("Forwarding handshake initiation from %s (sender index: %d)", remoteAddr, senderIndex)
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to get/create connection: %v", err)
|
for _, dest := range proxyMapping.Destinations {
|
||||||
return
|
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
||||||
}
|
if err != nil {
|
||||||
_, err = conn.Write(packet)
|
logger.Error("Failed to resolve destination address: %v", err)
|
||||||
if err != nil {
|
continue
|
||||||
logger.Error("Failed to proxy packet: %v", err)
|
}
|
||||||
|
|
||||||
|
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to get/create connection: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.Write(packet)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to forward handshake initiation: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case WireGuardMessageTypeHandshakeResponse:
|
||||||
|
// Received handshake response: establish session mapping
|
||||||
|
logger.Debug("Received handshake response with receiver index %d and sender index %d from %s",
|
||||||
|
receiverIndex, senderIndex, remoteAddr)
|
||||||
|
|
||||||
|
// Create a session key for the peer that sent the initial handshake
|
||||||
|
sessionKey := fmt.Sprintf("%d:%d", receiverIndex, senderIndex)
|
||||||
|
|
||||||
|
// Store the session information
|
||||||
|
s.wgSessions.Store(sessionKey, &WireGuardSession{
|
||||||
|
ReceiverIndex: receiverIndex,
|
||||||
|
SenderIndex: senderIndex,
|
||||||
|
DestAddr: remoteAddr,
|
||||||
|
LastSeen: time.Now(),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Forward the response to the original sender
|
||||||
|
for _, dest := range proxyMapping.Destinations {
|
||||||
|
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve destination address: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to get/create connection: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.Write(packet)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to forward handshake response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case WireGuardMessageTypeTransportData:
|
||||||
|
// Data packet: forward only to the established session peer
|
||||||
|
logger.Debug("Received transport data with receiver index %d from %s", receiverIndex, remoteAddr)
|
||||||
|
|
||||||
|
// Look up the session based on the receiver index
|
||||||
|
var destAddr *net.UDPAddr
|
||||||
|
|
||||||
|
// First check for existing sessions to see if we know where to send this packet
|
||||||
|
s.wgSessions.Range(func(k, v interface{}) bool {
|
||||||
|
session := v.(*WireGuardSession)
|
||||||
|
if session.SenderIndex == receiverIndex {
|
||||||
|
// Found matching session
|
||||||
|
destAddr = session.DestAddr
|
||||||
|
|
||||||
|
// Update last seen time
|
||||||
|
session.LastSeen = time.Now()
|
||||||
|
s.wgSessions.Store(k, session)
|
||||||
|
return false // stop iteration
|
||||||
|
}
|
||||||
|
return true // continue iteration
|
||||||
|
})
|
||||||
|
|
||||||
|
if destAddr != nil {
|
||||||
|
// We found a specific peer to forward to
|
||||||
|
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to get/create connection: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.Write(packet)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to forward transport data: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No known session, fall back to forwarding to all peers
|
||||||
|
logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex)
|
||||||
|
for _, dest := range proxyMapping.Destinations {
|
||||||
|
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve destination address: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to get/create connection: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.Write(packet)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to forward transport data: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Other packet types (like cookie reply)
|
||||||
|
logger.Debug("Forwarding WireGuard packet type %d from %s", messageType, remoteAddr)
|
||||||
|
|
||||||
|
// Forward to all peers
|
||||||
|
for _, dest := range proxyMapping.Destinations {
|
||||||
|
destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort))
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to resolve destination address: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := s.getOrCreateConnection(destAddr, remoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to get/create connection: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.Write(packet)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to forward WireGuard packet: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UDPProxyServer) getOrCreateConnection(destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) (*net.UDPConn, error) {
|
func (s *UDPProxyServer) getOrCreateConnection(destAddr *net.UDPAddr, remoteAddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||||
key := remoteAddr.String()
|
key := destAddr.String() + "-" + remoteAddr.String()
|
||||||
|
|
||||||
// Check if we have an existing connection
|
// Check if we have an existing connection
|
||||||
if conn, ok := s.connections.Load(key); ok {
|
if conn, ok := s.connections.Load(key); ok {
|
||||||
@@ -322,6 +518,22 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Process the response to track sessions if it's a WireGuard packet
|
||||||
|
if n > 0 && buffer[0] >= 1 && buffer[0] <= 4 {
|
||||||
|
receiverIndex, senderIndex, ok := extractWireGuardIndices(buffer[:n])
|
||||||
|
if ok && buffer[0] == WireGuardMessageTypeHandshakeResponse {
|
||||||
|
// Store the session mapping for the handshake response
|
||||||
|
sessionKey := fmt.Sprintf("%d:%d", senderIndex, receiverIndex)
|
||||||
|
s.wgSessions.Store(sessionKey, &WireGuardSession{
|
||||||
|
ReceiverIndex: receiverIndex,
|
||||||
|
SenderIndex: senderIndex,
|
||||||
|
DestAddr: destAddr,
|
||||||
|
LastSeen: time.Now(),
|
||||||
|
})
|
||||||
|
logger.Debug("Stored session mapping: %s -> %s", sessionKey, destAddr.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Forward the response back through the main listener
|
// Forward the response back through the main listener
|
||||||
_, err = s.conn.WriteToUDP(buffer[:n], remoteAddr)
|
_, err = s.conn.WriteToUDP(buffer[:n], remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -346,6 +558,22 @@ func (s *UDPProxyServer) cleanupIdleConnections() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// New method to periodically remove idle sessions
|
||||||
|
func (s *UDPProxyServer) cleanupIdleSessions() {
|
||||||
|
ticker := time.NewTicker(5 * time.Minute)
|
||||||
|
for range ticker.C {
|
||||||
|
now := time.Now()
|
||||||
|
s.wgSessions.Range(func(key, value interface{}) bool {
|
||||||
|
session := value.(*WireGuardSession)
|
||||||
|
if now.Sub(session.LastSeen) > 15*time.Minute {
|
||||||
|
s.wgSessions.Delete(key)
|
||||||
|
logger.Debug("Removed idle session: %s", key)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) {
|
func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) {
|
||||||
jsonData, err := json.Marshal(endpoint)
|
jsonData, err := json.Marshal(endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -380,15 +608,14 @@ func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) {
|
|||||||
key := fmt.Sprintf("%s:%d", endpoint.IP, endpoint.Port)
|
key := fmt.Sprintf("%s:%d", endpoint.IP, endpoint.Port)
|
||||||
s.proxyMappings.Store(key, mapping)
|
s.proxyMappings.Store(key, mapping)
|
||||||
|
|
||||||
logger.Debug("Stored proxy mapping for %s: %v", key, mapping)
|
logger.Debug("Stored proxy mapping for %s with %d destinations", key, len(mapping.Destinations))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int,
|
// Updated to support multiple destinations
|
||||||
destinationIP string, destinationPort int) {
|
func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, destinations []PeerDestination) {
|
||||||
key := net.JoinHostPort(sourceIP, strconv.Itoa(sourcePort))
|
key := net.JoinHostPort(sourceIP, strconv.Itoa(sourcePort))
|
||||||
mapping := ProxyMapping{
|
mapping := ProxyMapping{
|
||||||
DestinationIP: destinationIP,
|
Destinations: destinations,
|
||||||
DestinationPort: destinationPort,
|
|
||||||
}
|
}
|
||||||
s.proxyMappings.Store(key, mapping)
|
s.proxyMappings.Store(key, mapping)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user