From a3f9a89079eb6b5babb6bb32c14936d3fb9c3799 Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 13 Nov 2025 06:43:31 +0000 Subject: [PATCH] Refactor WireGuard session locking and remove unused methods - Remove unused methods: UpdateLastSeen, GetSenderIndex, MatchesSenderIndex (replaced by simpler direct usage in Range callbacks) - Simplify session access pattern: check GetSenderIndex in Range callback, then call GetDestAddr and UpdateLastSeen when match found - Optimize UpdateLastSeen usage: only use for existing sessions already in sync.Map; use direct assignment in struct literals for new sessions (safe since no concurrent access during creation) This simplifies the code while maintaining thread-safety for concurrent access to existing sessions. --- relay/relay.go | 41 ++++++++++------------------------------- 1 file changed, 10 insertions(+), 31 deletions(-) diff --git a/relay/relay.go b/relay/relay.go index aa2045d..595cbb5 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -65,13 +65,6 @@ type WireGuardSession struct { LastSeen time.Time } -// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner -func (s *WireGuardSession) UpdateLastSeen() { - s.mu.Lock() - defer s.mu.Unlock() - s.LastSeen = time.Now() -} - // GetSenderIndex returns the SenderIndex in a thread-safe manner func (s *WireGuardSession) GetSenderIndex() uint32 { s.mu.RLock() @@ -93,24 +86,11 @@ func (s *WireGuardSession) GetLastSeen() time.Time { return s.LastSeen } -// MatchesSenderIndex checks if the SenderIndex matches the given value in a thread-safe manner -func (s *WireGuardSession) MatchesSenderIndex(receiverIndex uint32) bool { - s.mu.RLock() - defer s.mu.RUnlock() - return s.SenderIndex == receiverIndex -} - -// CheckAndUpdateIfMatch atomically checks if SenderIndex matches and updates LastSeen if it does. -// Returns the DestAddr and true if there's a match, nil and false otherwise. -// This is more efficient than separate MatchesSenderIndex and UpdateLastSeen calls. -func (s *WireGuardSession) CheckAndUpdateIfMatch(receiverIndex uint32) (*net.UDPAddr, bool) { +// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner +func (s *WireGuardSession) UpdateLastSeen() { s.mu.Lock() defer s.mu.Unlock() - if s.SenderIndex == receiverIndex { - s.LastSeen = time.Now() - return s.DestAddr, true - } - return nil, false + s.LastSeen = time.Now() } // Type for tracking bidirectional communication patterns to rebuild sessions @@ -491,10 +471,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD // First check for existing sessions to see if we know where to send this packet s.wgSessions.Range(func(k, v interface{}) bool { session := v.(*WireGuardSession) - // Atomically check if session matches and update LastSeen if it does - if addr, matches := session.CheckAndUpdateIfMatch(receiverIndex); matches { - // Found matching session - destAddr = addr + // Check if session matches (read lock for check) + if session.GetSenderIndex() == receiverIndex { + // Found matching session - get dest addr and update last seen + destAddr = session.GetDestAddr() + session.UpdateLastSeen() return false // stop iteration } return true // continue iteration @@ -974,14 +955,12 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) { // Check if we already have this session if _, exists := s.wgSessions.Load(sessionKey); !exists { - session := &WireGuardSession{ + s.wgSessions.Store(sessionKey, &WireGuardSession{ ReceiverIndex: pattern.DestIndex, SenderIndex: pattern.ClientIndex, DestAddr: pattern.ToDestination, LastSeen: time.Now(), - } - - s.wgSessions.Store(sessionKey, session) + }) logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)", sessionKey, pattern.ToDestination.String(), pattern.PacketCount) }