Merge pull request #36 from LaurenceJJones/fix-wg-session-race-condition

fix: relay race condition in WireGuard session management
This commit is contained in:
Owen Schwartz
2025-12-06 12:12:04 -05:00
committed by GitHub

View File

@@ -58,12 +58,41 @@ type DestinationConn struct {
// Type for storing WireGuard handshake information // Type for storing WireGuard handshake information
type WireGuardSession struct { type WireGuardSession struct {
mu sync.RWMutex
ReceiverIndex uint32 ReceiverIndex uint32
SenderIndex uint32 SenderIndex uint32
DestAddr *net.UDPAddr DestAddr *net.UDPAddr
LastSeen time.Time LastSeen time.Time
} }
// GetSenderIndex returns the SenderIndex in a thread-safe manner
func (s *WireGuardSession) GetSenderIndex() uint32 {
s.mu.RLock()
defer s.mu.RUnlock()
return s.SenderIndex
}
// GetDestAddr returns the DestAddr in a thread-safe manner
func (s *WireGuardSession) GetDestAddr() *net.UDPAddr {
s.mu.RLock()
defer s.mu.RUnlock()
return s.DestAddr
}
// GetLastSeen returns the LastSeen timestamp in a thread-safe manner
func (s *WireGuardSession) GetLastSeen() time.Time {
s.mu.RLock()
defer s.mu.RUnlock()
return s.LastSeen
}
// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner
func (s *WireGuardSession) UpdateLastSeen() {
s.mu.Lock()
defer s.mu.Unlock()
s.LastSeen = time.Now()
}
// Type for tracking bidirectional communication patterns to rebuild sessions // Type for tracking bidirectional communication patterns to rebuild sessions
type CommunicationPattern struct { type CommunicationPattern struct {
FromClient *net.UDPAddr // The client address FromClient *net.UDPAddr // The client address
@@ -444,13 +473,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
// First check for existing sessions to see if we know where to send this packet // First check for existing sessions to see if we know where to send this packet
s.wgSessions.Range(func(k, v interface{}) bool { s.wgSessions.Range(func(k, v interface{}) bool {
session := v.(*WireGuardSession) session := v.(*WireGuardSession)
if session.SenderIndex == receiverIndex { // Check if session matches (read lock for check)
// Found matching session if session.GetSenderIndex() == receiverIndex {
destAddr = session.DestAddr // Found matching session - get dest addr and update last seen
destAddr = session.GetDestAddr()
// Update last seen time session.UpdateLastSeen()
session.LastSeen = time.Now()
s.wgSessions.Store(k, session)
return false // stop iteration return false // stop iteration
} }
return true // continue iteration return true // continue iteration
@@ -610,7 +637,8 @@ func (s *UDPProxyServer) cleanupIdleSessions() {
now := time.Now() now := time.Now()
s.wgSessions.Range(func(key, value interface{}) bool { s.wgSessions.Range(func(key, value interface{}) bool {
session := value.(*WireGuardSession) session := value.(*WireGuardSession)
if now.Sub(session.LastSeen) > 15*time.Minute { // Use thread-safe method to read LastSeen
if now.Sub(session.GetLastSeen()) > 15*time.Minute {
s.wgSessions.Delete(key) s.wgSessions.Delete(key)
logger.Debug("Removed idle session: %s", key) logger.Debug("Removed idle session: %s", key)
} }
@@ -737,8 +765,9 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) {
keyStr := key.(string) keyStr := key.(string)
session := value.(*WireGuardSession) session := value.(*WireGuardSession)
// Check if the session's destination address contains the WG IP // Check if the session's destination address contains the WG IP (thread-safe)
if session.DestAddr != nil && session.DestAddr.IP.String() == ip { destAddr := session.GetDestAddr()
if destAddr != nil && destAddr.IP.String() == ip {
keysToDelete = append(keysToDelete, keyStr) keysToDelete = append(keysToDelete, keyStr)
logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr) logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr)
} }
@@ -928,14 +957,12 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
// Check if we already have this session // Check if we already have this session
if _, exists := s.wgSessions.Load(sessionKey); !exists { if _, exists := s.wgSessions.Load(sessionKey); !exists {
session := &WireGuardSession{ s.wgSessions.Store(sessionKey, &WireGuardSession{
ReceiverIndex: pattern.DestIndex, ReceiverIndex: pattern.DestIndex,
SenderIndex: pattern.ClientIndex, SenderIndex: pattern.ClientIndex,
DestAddr: pattern.ToDestination, DestAddr: pattern.ToDestination,
LastSeen: time.Now(), LastSeen: time.Now(),
} })
s.wgSessions.Store(sessionKey, session)
logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)", logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)",
sessionKey, pattern.ToDestination.String(), pattern.PacketCount) sessionKey, pattern.ToDestination.String(), pattern.PacketCount)
} }