mirror of
https://github.com/fosrl/gerbil.git
synced 2026-02-07 21:46:40 +00:00
Fix race condition in WireGuard session management
The race condition existed because while sync.Map is thread-safe for map operations (Load, Store, Delete, Range), it does not provide thread-safety for the data stored within it. When WireGuardSession structs were stored as pointers in the sync.Map, multiple goroutines could: 1. Retrieve the same session pointer from the map concurrently 2. Access and modify the session's fields (particularly LastSeen) without synchronization 3. Cause data races when one goroutine reads LastSeen while another updates it This fix adds a sync.RWMutex to each WireGuardSession struct to protect concurrent access to its fields. All field access now goes through thread-safe methods that properly acquire/release the mutex. Changes: - Added sync.RWMutex to WireGuardSession struct - Added thread-safe accessor methods (GetLastSeen, GetDestAddr, etc.) - Added atomic CheckAndUpdateIfMatch method for efficient check-and-update - Updated all session field accesses to use thread-safe methods - Removed redundant Store call after updating LastSeen (pointer update is atomic in Go, but field access within pointer was not)
This commit is contained in:
@@ -58,12 +58,61 @@ type DestinationConn struct {
|
||||
|
||||
// Type for storing WireGuard handshake information
|
||||
type WireGuardSession struct {
|
||||
mu sync.RWMutex
|
||||
ReceiverIndex uint32
|
||||
SenderIndex uint32
|
||||
DestAddr *net.UDPAddr
|
||||
LastSeen time.Time
|
||||
}
|
||||
|
||||
// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner
|
||||
func (s *WireGuardSession) UpdateLastSeen() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.LastSeen = time.Now()
|
||||
}
|
||||
|
||||
// GetSenderIndex returns the SenderIndex in a thread-safe manner
|
||||
func (s *WireGuardSession) GetSenderIndex() uint32 {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.SenderIndex
|
||||
}
|
||||
|
||||
// GetDestAddr returns the DestAddr in a thread-safe manner
|
||||
func (s *WireGuardSession) GetDestAddr() *net.UDPAddr {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.DestAddr
|
||||
}
|
||||
|
||||
// GetLastSeen returns the LastSeen timestamp in a thread-safe manner
|
||||
func (s *WireGuardSession) GetLastSeen() time.Time {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.LastSeen
|
||||
}
|
||||
|
||||
// MatchesSenderIndex checks if the SenderIndex matches the given value in a thread-safe manner
|
||||
func (s *WireGuardSession) MatchesSenderIndex(receiverIndex uint32) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.SenderIndex == receiverIndex
|
||||
}
|
||||
|
||||
// CheckAndUpdateIfMatch atomically checks if SenderIndex matches and updates LastSeen if it does.
|
||||
// Returns the DestAddr and true if there's a match, nil and false otherwise.
|
||||
// This is more efficient than separate MatchesSenderIndex and UpdateLastSeen calls.
|
||||
func (s *WireGuardSession) CheckAndUpdateIfMatch(receiverIndex uint32) (*net.UDPAddr, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.SenderIndex == receiverIndex {
|
||||
s.LastSeen = time.Now()
|
||||
return s.DestAddr, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Type for tracking bidirectional communication patterns to rebuild sessions
|
||||
type CommunicationPattern struct {
|
||||
FromClient *net.UDPAddr // The client address
|
||||
@@ -442,13 +491,10 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
||||
// First check for existing sessions to see if we know where to send this packet
|
||||
s.wgSessions.Range(func(k, v interface{}) bool {
|
||||
session := v.(*WireGuardSession)
|
||||
if session.SenderIndex == receiverIndex {
|
||||
// Atomically check if session matches and update LastSeen if it does
|
||||
if addr, matches := session.CheckAndUpdateIfMatch(receiverIndex); matches {
|
||||
// Found matching session
|
||||
destAddr = session.DestAddr
|
||||
|
||||
// Update last seen time
|
||||
session.LastSeen = time.Now()
|
||||
s.wgSessions.Store(k, session)
|
||||
destAddr = addr
|
||||
return false // stop iteration
|
||||
}
|
||||
return true // continue iteration
|
||||
@@ -608,7 +654,8 @@ func (s *UDPProxyServer) cleanupIdleSessions() {
|
||||
now := time.Now()
|
||||
s.wgSessions.Range(func(key, value interface{}) bool {
|
||||
session := value.(*WireGuardSession)
|
||||
if now.Sub(session.LastSeen) > 15*time.Minute {
|
||||
// Use thread-safe method to read LastSeen
|
||||
if now.Sub(session.GetLastSeen()) > 15*time.Minute {
|
||||
s.wgSessions.Delete(key)
|
||||
logger.Debug("Removed idle session: %s", key)
|
||||
}
|
||||
@@ -735,8 +782,9 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) {
|
||||
keyStr := key.(string)
|
||||
session := value.(*WireGuardSession)
|
||||
|
||||
// Check if the session's destination address contains the WG IP
|
||||
if session.DestAddr != nil && session.DestAddr.IP.String() == ip {
|
||||
// Check if the session's destination address contains the WG IP (thread-safe)
|
||||
destAddr := session.GetDestAddr()
|
||||
if destAddr != nil && destAddr.IP.String() == ip {
|
||||
keysToDelete = append(keysToDelete, keyStr)
|
||||
logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user