From ee27bf3153e4054d75f3d39ae7d94f0c697857ce Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 13 Nov 2025 06:26:09 +0000 Subject: [PATCH 1/2] Fix race condition in WireGuard session management The race condition existed because while sync.Map is thread-safe for map operations (Load, Store, Delete, Range), it does not provide thread-safety for the data stored within it. When WireGuardSession structs were stored as pointers in the sync.Map, multiple goroutines could: 1. Retrieve the same session pointer from the map concurrently 2. Access and modify the session's fields (particularly LastSeen) without synchronization 3. Cause data races when one goroutine reads LastSeen while another updates it This fix adds a sync.RWMutex to each WireGuardSession struct to protect concurrent access to its fields. All field access now goes through thread-safe methods that properly acquire/release the mutex. Changes: - Added sync.RWMutex to WireGuardSession struct - Added thread-safe accessor methods (GetLastSeen, GetDestAddr, etc.) - Added atomic CheckAndUpdateIfMatch method for efficient check-and-update - Updated all session field accesses to use thread-safe methods - Removed redundant Store call after updating LastSeen (pointer update is atomic in Go, but field access within pointer was not) --- relay/relay.go | 66 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 57 insertions(+), 9 deletions(-) diff --git a/relay/relay.go b/relay/relay.go index e74ed87..aa2045d 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -58,12 +58,61 @@ type DestinationConn struct { // Type for storing WireGuard handshake information type WireGuardSession struct { + mu sync.RWMutex ReceiverIndex uint32 SenderIndex uint32 DestAddr *net.UDPAddr LastSeen time.Time } +// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner +func (s *WireGuardSession) UpdateLastSeen() { + s.mu.Lock() + defer s.mu.Unlock() + s.LastSeen = time.Now() +} + +// GetSenderIndex returns the SenderIndex in a thread-safe manner +func (s *WireGuardSession) GetSenderIndex() uint32 { + s.mu.RLock() + defer s.mu.RUnlock() + return s.SenderIndex +} + +// GetDestAddr returns the DestAddr in a thread-safe manner +func (s *WireGuardSession) GetDestAddr() *net.UDPAddr { + s.mu.RLock() + defer s.mu.RUnlock() + return s.DestAddr +} + +// GetLastSeen returns the LastSeen timestamp in a thread-safe manner +func (s *WireGuardSession) GetLastSeen() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() + return s.LastSeen +} + +// MatchesSenderIndex checks if the SenderIndex matches the given value in a thread-safe manner +func (s *WireGuardSession) MatchesSenderIndex(receiverIndex uint32) bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.SenderIndex == receiverIndex +} + +// CheckAndUpdateIfMatch atomically checks if SenderIndex matches and updates LastSeen if it does. +// Returns the DestAddr and true if there's a match, nil and false otherwise. +// This is more efficient than separate MatchesSenderIndex and UpdateLastSeen calls. +func (s *WireGuardSession) CheckAndUpdateIfMatch(receiverIndex uint32) (*net.UDPAddr, bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.SenderIndex == receiverIndex { + s.LastSeen = time.Now() + return s.DestAddr, true + } + return nil, false +} + // Type for tracking bidirectional communication patterns to rebuild sessions type CommunicationPattern struct { FromClient *net.UDPAddr // The client address @@ -442,13 +491,10 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD // First check for existing sessions to see if we know where to send this packet s.wgSessions.Range(func(k, v interface{}) bool { session := v.(*WireGuardSession) - if session.SenderIndex == receiverIndex { + // Atomically check if session matches and update LastSeen if it does + if addr, matches := session.CheckAndUpdateIfMatch(receiverIndex); matches { // Found matching session - destAddr = session.DestAddr - - // Update last seen time - session.LastSeen = time.Now() - s.wgSessions.Store(k, session) + destAddr = addr return false // stop iteration } return true // continue iteration @@ -608,7 +654,8 @@ func (s *UDPProxyServer) cleanupIdleSessions() { now := time.Now() s.wgSessions.Range(func(key, value interface{}) bool { session := value.(*WireGuardSession) - if now.Sub(session.LastSeen) > 15*time.Minute { + // Use thread-safe method to read LastSeen + if now.Sub(session.GetLastSeen()) > 15*time.Minute { s.wgSessions.Delete(key) logger.Debug("Removed idle session: %s", key) } @@ -735,8 +782,9 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) { keyStr := key.(string) session := value.(*WireGuardSession) - // Check if the session's destination address contains the WG IP - if session.DestAddr != nil && session.DestAddr.IP.String() == ip { + // Check if the session's destination address contains the WG IP (thread-safe) + destAddr := session.GetDestAddr() + if destAddr != nil && destAddr.IP.String() == ip { keysToDelete = append(keysToDelete, keyStr) logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr) } From a3f9a89079eb6b5babb6bb32c14936d3fb9c3799 Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 13 Nov 2025 06:43:31 +0000 Subject: [PATCH 2/2] Refactor WireGuard session locking and remove unused methods - Remove unused methods: UpdateLastSeen, GetSenderIndex, MatchesSenderIndex (replaced by simpler direct usage in Range callbacks) - Simplify session access pattern: check GetSenderIndex in Range callback, then call GetDestAddr and UpdateLastSeen when match found - Optimize UpdateLastSeen usage: only use for existing sessions already in sync.Map; use direct assignment in struct literals for new sessions (safe since no concurrent access during creation) This simplifies the code while maintaining thread-safety for concurrent access to existing sessions. --- relay/relay.go | 41 ++++++++++------------------------------- 1 file changed, 10 insertions(+), 31 deletions(-) diff --git a/relay/relay.go b/relay/relay.go index aa2045d..595cbb5 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -65,13 +65,6 @@ type WireGuardSession struct { LastSeen time.Time } -// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner -func (s *WireGuardSession) UpdateLastSeen() { - s.mu.Lock() - defer s.mu.Unlock() - s.LastSeen = time.Now() -} - // GetSenderIndex returns the SenderIndex in a thread-safe manner func (s *WireGuardSession) GetSenderIndex() uint32 { s.mu.RLock() @@ -93,24 +86,11 @@ func (s *WireGuardSession) GetLastSeen() time.Time { return s.LastSeen } -// MatchesSenderIndex checks if the SenderIndex matches the given value in a thread-safe manner -func (s *WireGuardSession) MatchesSenderIndex(receiverIndex uint32) bool { - s.mu.RLock() - defer s.mu.RUnlock() - return s.SenderIndex == receiverIndex -} - -// CheckAndUpdateIfMatch atomically checks if SenderIndex matches and updates LastSeen if it does. -// Returns the DestAddr and true if there's a match, nil and false otherwise. -// This is more efficient than separate MatchesSenderIndex and UpdateLastSeen calls. -func (s *WireGuardSession) CheckAndUpdateIfMatch(receiverIndex uint32) (*net.UDPAddr, bool) { +// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner +func (s *WireGuardSession) UpdateLastSeen() { s.mu.Lock() defer s.mu.Unlock() - if s.SenderIndex == receiverIndex { - s.LastSeen = time.Now() - return s.DestAddr, true - } - return nil, false + s.LastSeen = time.Now() } // Type for tracking bidirectional communication patterns to rebuild sessions @@ -491,10 +471,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD // First check for existing sessions to see if we know where to send this packet s.wgSessions.Range(func(k, v interface{}) bool { session := v.(*WireGuardSession) - // Atomically check if session matches and update LastSeen if it does - if addr, matches := session.CheckAndUpdateIfMatch(receiverIndex); matches { - // Found matching session - destAddr = addr + // Check if session matches (read lock for check) + if session.GetSenderIndex() == receiverIndex { + // Found matching session - get dest addr and update last seen + destAddr = session.GetDestAddr() + session.UpdateLastSeen() return false // stop iteration } return true // continue iteration @@ -974,14 +955,12 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) { // Check if we already have this session if _, exists := s.wgSessions.Load(sessionKey); !exists { - session := &WireGuardSession{ + s.wgSessions.Store(sessionKey, &WireGuardSession{ ReceiverIndex: pattern.DestIndex, SenderIndex: pattern.ClientIndex, DestAddr: pattern.ToDestination, LastSeen: time.Now(), - } - - s.wgSessions.Store(sessionKey, session) + }) logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)", sessionKey, pattern.ToDestination.String(), pattern.PacketCount) }