Fix race condition in WireGuard session management

The race condition existed because while sync.Map is thread-safe for map
operations (Load, Store, Delete, Range), it does not provide thread-safety
for the data stored within it. When WireGuardSession structs were stored as
pointers in the sync.Map, multiple goroutines could:

1. Retrieve the same session pointer from the map concurrently
2. Access and modify the session's fields (particularly LastSeen) without
   synchronization
3. Cause data races when one goroutine reads LastSeen while another updates it

This fix adds a sync.RWMutex to each WireGuardSession struct to protect
concurrent access to its fields. All field access now goes through
thread-safe methods that properly acquire/release the mutex.

Changes:
- Added sync.RWMutex to WireGuardSession struct
- Added thread-safe accessor methods (GetLastSeen, GetDestAddr, etc.)
- Added atomic CheckAndUpdateIfMatch method for efficient check-and-update
- Updated all session field accesses to use thread-safe methods
- Removed redundant Store call after updating LastSeen (pointer update is
  atomic in Go, but field access within pointer was not)
This commit is contained in:
Laurence
2025-11-13 06:26:09 +00:00
parent 2a1911a66f
commit ee27bf3153

View File

@@ -58,12 +58,61 @@ type DestinationConn struct {
// Type for storing WireGuard handshake information
type WireGuardSession struct {
mu sync.RWMutex
ReceiverIndex uint32
SenderIndex uint32
DestAddr *net.UDPAddr
LastSeen time.Time
}
// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner
func (s *WireGuardSession) UpdateLastSeen() {
s.mu.Lock()
defer s.mu.Unlock()
s.LastSeen = time.Now()
}
// GetSenderIndex returns the SenderIndex in a thread-safe manner
func (s *WireGuardSession) GetSenderIndex() uint32 {
s.mu.RLock()
defer s.mu.RUnlock()
return s.SenderIndex
}
// GetDestAddr returns the DestAddr in a thread-safe manner
func (s *WireGuardSession) GetDestAddr() *net.UDPAddr {
s.mu.RLock()
defer s.mu.RUnlock()
return s.DestAddr
}
// GetLastSeen returns the LastSeen timestamp in a thread-safe manner
func (s *WireGuardSession) GetLastSeen() time.Time {
s.mu.RLock()
defer s.mu.RUnlock()
return s.LastSeen
}
// MatchesSenderIndex checks if the SenderIndex matches the given value in a thread-safe manner
func (s *WireGuardSession) MatchesSenderIndex(receiverIndex uint32) bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.SenderIndex == receiverIndex
}
// CheckAndUpdateIfMatch atomically checks if SenderIndex matches and updates LastSeen if it does.
// Returns the DestAddr and true if there's a match, nil and false otherwise.
// This is more efficient than separate MatchesSenderIndex and UpdateLastSeen calls.
func (s *WireGuardSession) CheckAndUpdateIfMatch(receiverIndex uint32) (*net.UDPAddr, bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.SenderIndex == receiverIndex {
s.LastSeen = time.Now()
return s.DestAddr, true
}
return nil, false
}
// Type for tracking bidirectional communication patterns to rebuild sessions
type CommunicationPattern struct {
FromClient *net.UDPAddr // The client address
@@ -442,13 +491,10 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
// First check for existing sessions to see if we know where to send this packet
s.wgSessions.Range(func(k, v interface{}) bool {
session := v.(*WireGuardSession)
if session.SenderIndex == receiverIndex {
// Atomically check if session matches and update LastSeen if it does
if addr, matches := session.CheckAndUpdateIfMatch(receiverIndex); matches {
// Found matching session
destAddr = session.DestAddr
// Update last seen time
session.LastSeen = time.Now()
s.wgSessions.Store(k, session)
destAddr = addr
return false // stop iteration
}
return true // continue iteration
@@ -608,7 +654,8 @@ func (s *UDPProxyServer) cleanupIdleSessions() {
now := time.Now()
s.wgSessions.Range(func(key, value interface{}) bool {
session := value.(*WireGuardSession)
if now.Sub(session.LastSeen) > 15*time.Minute {
// Use thread-safe method to read LastSeen
if now.Sub(session.GetLastSeen()) > 15*time.Minute {
s.wgSessions.Delete(key)
logger.Debug("Removed idle session: %s", key)
}
@@ -735,8 +782,9 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) {
keyStr := key.(string)
session := value.(*WireGuardSession)
// Check if the session's destination address contains the WG IP
if session.DestAddr != nil && session.DestAddr.IP.String() == ip {
// Check if the session's destination address contains the WG IP (thread-safe)
destAddr := session.GetDestAddr()
if destAddr != nil && destAddr.IP.String() == ip {
keysToDelete = append(keysToDelete, keyStr)
logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr)
}