diff --git a/relay/relay.go b/relay/relay.go index e74ed87..aa2045d 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -58,12 +58,61 @@ type DestinationConn struct { // Type for storing WireGuard handshake information type WireGuardSession struct { + mu sync.RWMutex ReceiverIndex uint32 SenderIndex uint32 DestAddr *net.UDPAddr LastSeen time.Time } +// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner +func (s *WireGuardSession) UpdateLastSeen() { + s.mu.Lock() + defer s.mu.Unlock() + s.LastSeen = time.Now() +} + +// GetSenderIndex returns the SenderIndex in a thread-safe manner +func (s *WireGuardSession) GetSenderIndex() uint32 { + s.mu.RLock() + defer s.mu.RUnlock() + return s.SenderIndex +} + +// GetDestAddr returns the DestAddr in a thread-safe manner +func (s *WireGuardSession) GetDestAddr() *net.UDPAddr { + s.mu.RLock() + defer s.mu.RUnlock() + return s.DestAddr +} + +// GetLastSeen returns the LastSeen timestamp in a thread-safe manner +func (s *WireGuardSession) GetLastSeen() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() + return s.LastSeen +} + +// MatchesSenderIndex checks if the SenderIndex matches the given value in a thread-safe manner +func (s *WireGuardSession) MatchesSenderIndex(receiverIndex uint32) bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.SenderIndex == receiverIndex +} + +// CheckAndUpdateIfMatch atomically checks if SenderIndex matches and updates LastSeen if it does. +// Returns the DestAddr and true if there's a match, nil and false otherwise. +// This is more efficient than separate MatchesSenderIndex and UpdateLastSeen calls. +func (s *WireGuardSession) CheckAndUpdateIfMatch(receiverIndex uint32) (*net.UDPAddr, bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.SenderIndex == receiverIndex { + s.LastSeen = time.Now() + return s.DestAddr, true + } + return nil, false +} + // Type for tracking bidirectional communication patterns to rebuild sessions type CommunicationPattern struct { FromClient *net.UDPAddr // The client address @@ -442,13 +491,10 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD // First check for existing sessions to see if we know where to send this packet s.wgSessions.Range(func(k, v interface{}) bool { session := v.(*WireGuardSession) - if session.SenderIndex == receiverIndex { + // Atomically check if session matches and update LastSeen if it does + if addr, matches := session.CheckAndUpdateIfMatch(receiverIndex); matches { // Found matching session - destAddr = session.DestAddr - - // Update last seen time - session.LastSeen = time.Now() - s.wgSessions.Store(k, session) + destAddr = addr return false // stop iteration } return true // continue iteration @@ -608,7 +654,8 @@ func (s *UDPProxyServer) cleanupIdleSessions() { now := time.Now() s.wgSessions.Range(func(key, value interface{}) bool { session := value.(*WireGuardSession) - if now.Sub(session.LastSeen) > 15*time.Minute { + // Use thread-safe method to read LastSeen + if now.Sub(session.GetLastSeen()) > 15*time.Minute { s.wgSessions.Delete(key) logger.Debug("Removed idle session: %s", key) } @@ -735,8 +782,9 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) { keyStr := key.(string) session := value.(*WireGuardSession) - // Check if the session's destination address contains the WG IP - if session.DestAddr != nil && session.DestAddr.IP.String() == ip { + // Check if the session's destination address contains the WG IP (thread-safe) + destAddr := session.GetDestAddr() + if destAddr != nil && destAddr.IP.String() == ip { keysToDelete = append(keysToDelete, keyStr) logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr) }