diff --git a/relay/relay.go b/relay/relay.go index e0a6a98..a01ce42 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -58,12 +58,41 @@ type DestinationConn struct { // Type for storing WireGuard handshake information type WireGuardSession struct { + mu sync.RWMutex ReceiverIndex uint32 SenderIndex uint32 DestAddr *net.UDPAddr LastSeen time.Time } +// GetSenderIndex returns the SenderIndex in a thread-safe manner +func (s *WireGuardSession) GetSenderIndex() uint32 { + s.mu.RLock() + defer s.mu.RUnlock() + return s.SenderIndex +} + +// GetDestAddr returns the DestAddr in a thread-safe manner +func (s *WireGuardSession) GetDestAddr() *net.UDPAddr { + s.mu.RLock() + defer s.mu.RUnlock() + return s.DestAddr +} + +// GetLastSeen returns the LastSeen timestamp in a thread-safe manner +func (s *WireGuardSession) GetLastSeen() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() + return s.LastSeen +} + +// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner +func (s *WireGuardSession) UpdateLastSeen() { + s.mu.Lock() + defer s.mu.Unlock() + s.LastSeen = time.Now() +} + // Type for tracking bidirectional communication patterns to rebuild sessions type CommunicationPattern struct { FromClient *net.UDPAddr // The client address @@ -444,13 +473,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD // First check for existing sessions to see if we know where to send this packet s.wgSessions.Range(func(k, v interface{}) bool { session := v.(*WireGuardSession) - if session.SenderIndex == receiverIndex { - // Found matching session - destAddr = session.DestAddr - - // Update last seen time - session.LastSeen = time.Now() - s.wgSessions.Store(k, session) + // Check if session matches (read lock for check) + if session.GetSenderIndex() == receiverIndex { + // Found matching session - get dest addr and update last seen + destAddr = session.GetDestAddr() + session.UpdateLastSeen() return false // stop iteration } return true // continue iteration @@ -610,7 +637,8 @@ func (s *UDPProxyServer) cleanupIdleSessions() { now := time.Now() s.wgSessions.Range(func(key, value interface{}) bool { session := value.(*WireGuardSession) - if now.Sub(session.LastSeen) > 15*time.Minute { + // Use thread-safe method to read LastSeen + if now.Sub(session.GetLastSeen()) > 15*time.Minute { s.wgSessions.Delete(key) logger.Debug("Removed idle session: %s", key) } @@ -737,8 +765,9 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) { keyStr := key.(string) session := value.(*WireGuardSession) - // Check if the session's destination address contains the WG IP - if session.DestAddr != nil && session.DestAddr.IP.String() == ip { + // Check if the session's destination address contains the WG IP (thread-safe) + destAddr := session.GetDestAddr() + if destAddr != nil && destAddr.IP.String() == ip { keysToDelete = append(keysToDelete, keyStr) logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr) } @@ -928,14 +957,12 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) { // Check if we already have this session if _, exists := s.wgSessions.Load(sessionKey); !exists { - session := &WireGuardSession{ + s.wgSessions.Store(sessionKey, &WireGuardSession{ ReceiverIndex: pattern.DestIndex, SenderIndex: pattern.ClientIndex, DestAddr: pattern.ToDestination, LastSeen: time.Now(), - } - - s.wgSessions.Store(sessionKey, session) + }) logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)", sessionKey, pattern.ToDestination.String(), pattern.PacketCount) }