mirror of
https://github.com/fosrl/gerbil.git
synced 2026-02-08 05:56:40 +00:00
Merge pull request #36 from LaurenceJJones/fix-wg-session-race-condition
fix: relay race condition in WireGuard session management
This commit is contained in:
@@ -58,12 +58,41 @@ type DestinationConn struct {
|
|||||||
|
|
||||||
// Type for storing WireGuard handshake information
|
// Type for storing WireGuard handshake information
|
||||||
type WireGuardSession struct {
|
type WireGuardSession struct {
|
||||||
|
mu sync.RWMutex
|
||||||
ReceiverIndex uint32
|
ReceiverIndex uint32
|
||||||
SenderIndex uint32
|
SenderIndex uint32
|
||||||
DestAddr *net.UDPAddr
|
DestAddr *net.UDPAddr
|
||||||
LastSeen time.Time
|
LastSeen time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSenderIndex returns the SenderIndex in a thread-safe manner
|
||||||
|
func (s *WireGuardSession) GetSenderIndex() uint32 {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.SenderIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDestAddr returns the DestAddr in a thread-safe manner
|
||||||
|
func (s *WireGuardSession) GetDestAddr() *net.UDPAddr {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.DestAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLastSeen returns the LastSeen timestamp in a thread-safe manner
|
||||||
|
func (s *WireGuardSession) GetLastSeen() time.Time {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.LastSeen
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateLastSeen updates the LastSeen timestamp in a thread-safe manner
|
||||||
|
func (s *WireGuardSession) UpdateLastSeen() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.LastSeen = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
// Type for tracking bidirectional communication patterns to rebuild sessions
|
// Type for tracking bidirectional communication patterns to rebuild sessions
|
||||||
type CommunicationPattern struct {
|
type CommunicationPattern struct {
|
||||||
FromClient *net.UDPAddr // The client address
|
FromClient *net.UDPAddr // The client address
|
||||||
@@ -444,13 +473,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD
|
|||||||
// First check for existing sessions to see if we know where to send this packet
|
// First check for existing sessions to see if we know where to send this packet
|
||||||
s.wgSessions.Range(func(k, v interface{}) bool {
|
s.wgSessions.Range(func(k, v interface{}) bool {
|
||||||
session := v.(*WireGuardSession)
|
session := v.(*WireGuardSession)
|
||||||
if session.SenderIndex == receiverIndex {
|
// Check if session matches (read lock for check)
|
||||||
// Found matching session
|
if session.GetSenderIndex() == receiverIndex {
|
||||||
destAddr = session.DestAddr
|
// Found matching session - get dest addr and update last seen
|
||||||
|
destAddr = session.GetDestAddr()
|
||||||
// Update last seen time
|
session.UpdateLastSeen()
|
||||||
session.LastSeen = time.Now()
|
|
||||||
s.wgSessions.Store(k, session)
|
|
||||||
return false // stop iteration
|
return false // stop iteration
|
||||||
}
|
}
|
||||||
return true // continue iteration
|
return true // continue iteration
|
||||||
@@ -610,7 +637,8 @@ func (s *UDPProxyServer) cleanupIdleSessions() {
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
s.wgSessions.Range(func(key, value interface{}) bool {
|
s.wgSessions.Range(func(key, value interface{}) bool {
|
||||||
session := value.(*WireGuardSession)
|
session := value.(*WireGuardSession)
|
||||||
if now.Sub(session.LastSeen) > 15*time.Minute {
|
// Use thread-safe method to read LastSeen
|
||||||
|
if now.Sub(session.GetLastSeen()) > 15*time.Minute {
|
||||||
s.wgSessions.Delete(key)
|
s.wgSessions.Delete(key)
|
||||||
logger.Debug("Removed idle session: %s", key)
|
logger.Debug("Removed idle session: %s", key)
|
||||||
}
|
}
|
||||||
@@ -737,8 +765,9 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) {
|
|||||||
keyStr := key.(string)
|
keyStr := key.(string)
|
||||||
session := value.(*WireGuardSession)
|
session := value.(*WireGuardSession)
|
||||||
|
|
||||||
// Check if the session's destination address contains the WG IP
|
// Check if the session's destination address contains the WG IP (thread-safe)
|
||||||
if session.DestAddr != nil && session.DestAddr.IP.String() == ip {
|
destAddr := session.GetDestAddr()
|
||||||
|
if destAddr != nil && destAddr.IP.String() == ip {
|
||||||
keysToDelete = append(keysToDelete, keyStr)
|
keysToDelete = append(keysToDelete, keyStr)
|
||||||
logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr)
|
logger.Debug("Marking session for deletion for WG IP %s: %s", ip, keyStr)
|
||||||
}
|
}
|
||||||
@@ -928,14 +957,12 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
|
|||||||
|
|
||||||
// Check if we already have this session
|
// Check if we already have this session
|
||||||
if _, exists := s.wgSessions.Load(sessionKey); !exists {
|
if _, exists := s.wgSessions.Load(sessionKey); !exists {
|
||||||
session := &WireGuardSession{
|
s.wgSessions.Store(sessionKey, &WireGuardSession{
|
||||||
ReceiverIndex: pattern.DestIndex,
|
ReceiverIndex: pattern.DestIndex,
|
||||||
SenderIndex: pattern.ClientIndex,
|
SenderIndex: pattern.ClientIndex,
|
||||||
DestAddr: pattern.ToDestination,
|
DestAddr: pattern.ToDestination,
|
||||||
LastSeen: time.Now(),
|
LastSeen: time.Now(),
|
||||||
}
|
})
|
||||||
|
|
||||||
s.wgSessions.Store(sessionKey, session)
|
|
||||||
logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)",
|
logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)",
|
||||||
sessionKey, pattern.ToDestination.String(), pattern.PacketCount)
|
sessionKey, pattern.ToDestination.String(), pattern.PacketCount)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user