mirror of
https://github.com/fosrl/gerbil.git
synced 2026-02-07 21:46:40 +00:00
Merge branch 'main' into dev
This commit is contained in:
222
relay/relay.go
222
relay/relay.go
@@ -2,6 +2,7 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -60,12 +61,41 @@ type DestinationConn struct {
|
||||
|
||||
// Type for storing WireGuard handshake information
|
||||
type WireGuardSession struct {
|
||||
mu sync.RWMutex
|
||||
ReceiverIndex uint32
|
||||
SenderIndex uint32
|
||||
DestAddr *net.UDPAddr
|
||||
LastSeen time.Time
|
||||
}
|
||||
|
||||
// GetSenderIndex returns the SenderIndex in a thread-safe manner
|
||||
func (s *WireGuardSession) GetSenderIndex() uint32 {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.SenderIndex
|
||||
}
|
||||
|
||||
// GetDestAddr returns the DestAddr in a thread-safe manner
|
||||
func (s *WireGuardSession) GetDestAddr() *net.UDPAddr {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.DestAddr
|
||||
}
|
||||
|
||||
// GetLastSeen returns the LastSeen timestamp in a thread-safe manner
|
||||
func (s *WireGuardSession) GetLastSeen() time.Time {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.LastSeen
|
||||
}
|
||||
|
||||
// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner
|
||||
func (s *WireGuardSession) UpdateLastSeen() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.LastSeen = time.Now()
|
||||
}
|
||||
|
||||
// Type for tracking bidirectional communication patterns to rebuild sessions
|
||||
type CommunicationPattern struct {
|
||||
FromClient *net.UDPAddr // The client address
|
||||
@@ -114,6 +144,8 @@ type UDPProxyServer struct {
|
||||
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
|
||||
privateKey wgtypes.Key
|
||||
packetChan chan Packet
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Session tracking for WireGuard peers
|
||||
// Key format: "senderIndex:receiverIndex"
|
||||
@@ -125,14 +157,17 @@ type UDPProxyServer struct {
|
||||
ReachableAt string
|
||||
}
|
||||
|
||||
// NewUDPProxyServer initializes the server with a buffered packet channel.
|
||||
func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
|
||||
// NewUDPProxyServer initializes the server with a buffered packet channel and derived context.
|
||||
func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
return &UDPProxyServer{
|
||||
addr: addr,
|
||||
serverURL: serverURL,
|
||||
privateKey: privateKey,
|
||||
packetChan: make(chan Packet, 1000),
|
||||
ReachableAt: reachableAt,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,17 +214,51 @@ func (s *UDPProxyServer) Start() error {
|
||||
}
|
||||
|
||||
func (s *UDPProxyServer) Stop() {
|
||||
s.conn.Close()
|
||||
// Signal all background goroutines to stop
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
// Close listener to unblock reads
|
||||
if s.conn != nil {
|
||||
_ = s.conn.Close()
|
||||
}
|
||||
// Close all downstream UDP connections
|
||||
s.connections.Range(func(key, value interface{}) bool {
|
||||
if dc, ok := value.(*DestinationConn); ok && dc.conn != nil {
|
||||
_ = dc.conn.Close()
|
||||
}
|
||||
return true
|
||||
})
|
||||
// Close packet channel to stop workers
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
default:
|
||||
}
|
||||
close(s.packetChan)
|
||||
}
|
||||
|
||||
// readPackets continuously reads from the UDP socket and pushes packets into the channel.
|
||||
func (s *UDPProxyServer) readPackets() {
|
||||
for {
|
||||
// Exit promptly if context is canceled
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
buf := bufferPool.Get().([]byte)
|
||||
n, remoteAddr, err := s.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
logger.Error("Error reading UDP packet: %v", err)
|
||||
continue
|
||||
// If we're shutting down, exit
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
bufferPool.Put(buf[:1500])
|
||||
return
|
||||
default:
|
||||
logger.Error("Error reading UDP packet: %v", err)
|
||||
bufferPool.Put(buf[:1500])
|
||||
continue
|
||||
}
|
||||
}
|
||||
s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
|
||||
}
|
||||
@@ -445,13 +514,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
// First check for existing sessions to see if we know where to send this packet
|
||||
s.wgSessions.Range(func(k, v interface{}) bool {
|
||||
session := v.(*WireGuardSession)
|
||||
if session.SenderIndex == receiverIndex {
|
||||
// Found matching session
|
||||
destAddr = session.DestAddr
|
||||
|
||||
// Update last seen time
|
||||
session.LastSeen = time.Now()
|
||||
s.wgSessions.Store(k, session)
|
||||
// Check if session matches (read lock for check)
|
||||
if session.GetSenderIndex() == receiverIndex {
|
||||
// Found matching session - get dest addr and update last seen
|
||||
destAddr = session.GetDestAddr()
|
||||
session.UpdateLastSeen()
|
||||
return false // stop iteration
|
||||
}
|
||||
return true // continue iteration
|
||||
@@ -591,49 +658,69 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
|
||||
// Add a cleanup method to periodically remove idle connections
|
||||
func (s *UDPProxyServer) cleanupIdleConnections() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.connections.Range(func(key, value interface{}) bool {
|
||||
destConn := value.(*DestinationConn)
|
||||
if now.Sub(destConn.lastUsed) > 10*time.Minute {
|
||||
destConn.conn.Close()
|
||||
s.connections.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.connections.Range(func(key, value interface{}) bool {
|
||||
destConn := value.(*DestinationConn)
|
||||
if now.Sub(destConn.lastUsed) > 10*time.Minute {
|
||||
destConn.conn.Close()
|
||||
s.connections.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// New method to periodically remove idle sessions
|
||||
func (s *UDPProxyServer) cleanupIdleSessions() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.wgSessions.Range(func(key, value interface{}) bool {
|
||||
session := value.(*WireGuardSession)
|
||||
if now.Sub(session.LastSeen) > 15*time.Minute {
|
||||
s.wgSessions.Delete(key)
|
||||
logger.Debug("Removed idle session: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.wgSessions.Range(func(key, value interface{}) bool {
|
||||
session := value.(*WireGuardSession)
|
||||
// Use thread-safe method to read LastSeen
|
||||
if now.Sub(session.GetLastSeen()) > 15*time.Minute {
|
||||
s.wgSessions.Delete(key)
|
||||
logger.Debug("Removed idle session: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// New method to periodically remove idle proxy mappings
|
||||
func (s *UDPProxyServer) cleanupIdleProxyMappings() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.proxyMappings.Range(func(key, value interface{}) bool {
|
||||
mapping := value.(ProxyMapping)
|
||||
// Remove mappings that haven't been used in 30 minutes
|
||||
if now.Sub(mapping.LastUsed) > 30*time.Minute {
|
||||
s.proxyMappings.Delete(key)
|
||||
logger.Debug("Removed idle proxy mapping: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.proxyMappings.Range(func(key, value interface{}) bool {
|
||||
mapping := value.(ProxyMapping)
|
||||
// Remove mappings that haven't been used in 30 minutes
|
||||
if now.Sub(mapping.LastUsed) > 30*time.Minute {
|
||||
s.proxyMappings.Delete(key)
|
||||
logger.Debug("Removed idle proxy mapping: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -738,8 +825,9 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) {
|
||||
keyStr := key.(string)
|
||||
session := value.(*WireGuardSession)
|
||||
|
||||
// Check if the session's destination address contains the WG IP
|
||||
if session.DestAddr != nil && session.DestAddr.IP.String() == ip {
|
||||
// Check if the session's destination address contains the WG IP (thread-safe)
|
||||
destAddr := session.GetDestAddr()
|
||||
if destAddr != nil && destAddr.IP.String() == ip {
|
||||
keysToDelete = append(keysToDelete, keyStr)
|
||||
logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr)
|
||||
}
|
||||
@@ -929,14 +1017,12 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
|
||||
|
||||
// Check if we already have this session
|
||||
if _, exists := s.wgSessions.Load(sessionKey); !exists {
|
||||
session := &WireGuardSession{
|
||||
s.wgSessions.Store(sessionKey, &WireGuardSession{
|
||||
ReceiverIndex: pattern.DestIndex,
|
||||
SenderIndex: pattern.ClientIndex,
|
||||
DestAddr: pattern.ToDestination,
|
||||
LastSeen: time.Now(),
|
||||
}
|
||||
|
||||
s.wgSessions.Store(sessionKey, session)
|
||||
})
|
||||
logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)",
|
||||
sessionKey, pattern.ToDestination.String(), pattern.PacketCount)
|
||||
}
|
||||
@@ -946,23 +1032,29 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
|
||||
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns
|
||||
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.commPatterns.Range(func(key, value interface{}) bool {
|
||||
pattern := value.(*CommunicationPattern)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.commPatterns.Range(func(key, value interface{}) bool {
|
||||
pattern := value.(*CommunicationPattern)
|
||||
|
||||
// Get the most recent activity
|
||||
lastActivity := pattern.LastFromClient
|
||||
if pattern.LastFromDest.After(lastActivity) {
|
||||
lastActivity = pattern.LastFromDest
|
||||
}
|
||||
// Get the most recent activity
|
||||
lastActivity := pattern.LastFromClient
|
||||
if pattern.LastFromDest.After(lastActivity) {
|
||||
lastActivity = pattern.LastFromDest
|
||||
}
|
||||
|
||||
// Remove patterns that haven't had activity in 20 minutes
|
||||
if now.Sub(lastActivity) > 20*time.Minute {
|
||||
s.commPatterns.Delete(key)
|
||||
logger.Debug("Removed idle communication pattern: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
// Remove patterns that haven't had activity in 20 minutes
|
||||
if now.Sub(lastActivity) > 20*time.Minute {
|
||||
s.commPatterns.Delete(key)
|
||||
logger.Debug("Removed idle communication pattern: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user