From ea461e0bfb88290a24f496d94a3f45e7114795e1 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 23 Jan 2026 14:47:54 -0800 Subject: [PATCH] Prevent crashing on close before connect --- olm/connect.go | 18 ++++++++++++ olm/data.go | 18 ++++++++++++ olm/olm.go | 75 ++++++++++++++++++++++++++++++++++++++++++-------- olm/peer.go | 24 ++++++++++++++++ 4 files changed, 124 insertions(+), 11 deletions(-) diff --git a/olm/connect.go b/olm/connect.go index 575a8fd..3048cde 100644 --- a/olm/connect.go +++ b/olm/connect.go @@ -28,6 +28,12 @@ type OlmErrorData struct { func (o *Olm) handleConnect(msg websocket.WSMessage) { logger.Debug("Received message: %v", msg.Data) + // Check if tunnel is still running + if !o.tunnelRunning { + logger.Debug("Tunnel stopped, ignoring connect message") + return + } + var wgData WgData if o.connected { @@ -218,6 +224,12 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { func (o *Olm) handleOlmError(msg websocket.WSMessage) { logger.Debug("Received olm error message: %v", msg.Data) + // Check if tunnel is still running + if !o.tunnelRunning { + logger.Debug("Tunnel stopped, ignoring olm error message") + return + } + var errorData OlmErrorData jsonData, err := json.Marshal(msg.Data) @@ -245,6 +257,12 @@ func (o *Olm) handleOlmError(msg websocket.WSMessage) { func (o *Olm) handleTerminate(msg websocket.WSMessage) { logger.Info("Received terminate message") + // Check if tunnel is still running + if !o.tunnelRunning { + logger.Debug("Tunnel stopped, ignoring terminate message") + return + } + var errorData OlmErrorData jsonData, err := json.Marshal(msg.Data) diff --git a/olm/data.go b/olm/data.go index 35798c6..050a23f 100644 --- a/olm/data.go +++ b/olm/data.go @@ -13,6 +13,12 @@ import ( func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) { logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data) + // Check if tunnel is still running + if !o.tunnelRunning { + logger.Debug("Tunnel stopped, ignoring add-remote-subnets-aliases message") + return + } + jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling data: %v", err) @@ -48,6 +54,12 @@ func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) { func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) { logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data) + // Check if tunnel is still running + if !o.tunnelRunning { + logger.Debug("Tunnel stopped, ignoring remove-remote-subnets-aliases message") + return + } + jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling data: %v", err) @@ -83,6 +95,12 @@ func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) { func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) { logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data) + // Check if tunnel is still running + if !o.tunnelRunning { + logger.Debug("Tunnel stopped, ignoring update-remote-subnets-aliases message") + return + } + jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling data: %v", err) diff --git a/olm/olm.go b/olm/olm.go index cd8a844..e3a9d77 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -8,6 +8,7 @@ import ( _ "net/http/pprof" "os" "sync" + "syscall" "time" "github.com/fosrl/newt/bind" @@ -66,6 +67,9 @@ type Olm struct { updateRegister func(newData any) stopPeerSend func() + + // WaitGroup to track tunnel lifecycle + tunnelWg sync.WaitGroup } // initTunnelInfo creates the shared UDP socket and holepunch manager. @@ -389,11 +393,23 @@ func (o *Olm) StartTunnel(config TunnelConfig) { return nil } + // Check if tunnel is still running before starting registration + if !o.tunnelRunning { + logger.Debug("Tunnel is no longer running, skipping registration") + return nil + } + publicKey := o.privateKey.PublicKey() // delay for 500ms to allow for time for the hp to get processed time.Sleep(500 * time.Millisecond) + // Check again after sleep in case tunnel was stopped + if !o.tunnelRunning { + logger.Debug("Tunnel stopped during delay, skipping registration") + return nil + } + if o.stopRegister == nil { logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch) o.stopRegister, o.updateRegister = o.websocket.SendMessageInterval("olm/wg/register", map[string]any{ @@ -417,6 +433,12 @@ func (o *Olm) StartTunnel(config TunnelConfig) { }) o.websocket.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) { + // Check if tunnel is still running and hole punch manager exists + if !o.tunnelRunning || o.holePunchManager == nil { + logger.Debug("Tunnel stopped or hole punch manager nil, ignoring token update") + return + } + o.holePunchManager.SetToken(token) logger.Debug("Got exit nodes for hole punching: %v", exitNodes) @@ -447,6 +469,12 @@ func (o *Olm) StartTunnel(config TunnelConfig) { }) o.websocket.OnAuthError(func(statusCode int, message string) { + // Check if tunnel is still running + if !o.tunnelRunning { + logger.Debug("Tunnel stopped, ignoring auth error") + return + } + logger.Error("Authentication error (status %d): %s. Terminating tunnel.", statusCode, message) o.apiServer.SetTerminated(true) o.apiServer.SetConnectionStatus(false) @@ -466,6 +494,10 @@ func (o *Olm) StartTunnel(config TunnelConfig) { } }) + // Indicate that tunnel is starting + o.tunnelWg.Add(1) + defer o.tunnelWg.Done() + // Connect to the WebSocket server if err := o.websocket.Connect(); err != nil { logger.Error("Failed to connect to server: %v", err) @@ -479,6 +511,13 @@ func (o *Olm) StartTunnel(config TunnelConfig) { } func (o *Olm) Close() { + // Stop registration first to prevent it from trying to use closed websocket + if o.stopRegister != nil { + logger.Debug("Stopping registration interval") + o.stopRegister() + o.stopRegister = nil + } + // send a disconnect message to the cloud to show disconnected if o.websocket != nil { o.websocket.SendMessage("olm/disconnecting", map[string]any{}) @@ -498,11 +537,6 @@ func (o *Olm) Close() { o.holePunchManager = nil } - if o.stopRegister != nil { - o.stopRegister() - o.stopRegister = nil - } - // Close() also calls Stop() internally if o.peerManager != nil { o.peerManager.Close() @@ -533,6 +567,21 @@ func (o *Olm) Close() { logger.Debug("Closing MiddleDevice") _ = o.middleDev.Close() o.middleDev = nil + } else if o.tdev != nil { + // If middleDev was never created but tdev exists, close it directly + logger.Debug("Closing TUN device directly (no MiddleDevice)") + _ = o.tdev.Close() + o.tdev = nil + } else if o.tunnelConfig.FileDescriptorTun != 0 { + // If we never created a device from the FD, close it explicitly + // This can happen if tunnel is stopped during registration before handleConnect + logger.Debug("Closing unused TUN file descriptor %d", o.tunnelConfig.FileDescriptorTun) + if err := syscall.Close(int(o.tunnelConfig.FileDescriptorTun)); err != nil { + logger.Error("Failed to close TUN file descriptor: %v", err) + } else { + logger.Info("Closed unused TUN file descriptor") + } + o.tunnelConfig.FileDescriptorTun = 0 } // Now close WireGuard device - its TUN reader should have exited by now @@ -565,20 +614,24 @@ func (o *Olm) StopTunnel() error { return nil } + // Reset the running state BEFORE cleanup to prevent callbacks from accessing nil pointers + o.connected = false + o.tunnelRunning = false + // Cancel the tunnel context if it exists if o.tunnelCancel != nil { + logger.Debug("Cancelling tunnel context") o.tunnelCancel() - // Give it a moment to clean up - time.Sleep(200 * time.Millisecond) } + // Wait for the tunnel goroutine to complete + logger.Debug("Waiting for tunnel goroutine to finish") + o.tunnelWg.Wait() + logger.Debug("Tunnel goroutine finished") + // Close() will handle sending disconnect message and closing websocket o.Close() - // Reset the connected state - o.connected = false - o.tunnelRunning = false - // Update API server status o.apiServer.SetConnectionStatus(false) o.apiServer.SetRegistered(false) diff --git a/olm/peer.go b/olm/peer.go index 56e298d..8007272 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -14,6 +14,12 @@ import ( func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) { logger.Debug("Received add-peer message: %v", msg.Data) + // Check if tunnel is still running + if !o.tunnelRunning { + logger.Debug("Tunnel stopped, ignoring add-peer message") + return + } + if o.stopPeerSend != nil { o.stopPeerSend() o.stopPeerSend = nil @@ -44,6 +50,12 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) { func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) { logger.Debug("Received remove-peer message: %v", msg.Data) + // Check if tunnel is still running + if !o.tunnelRunning { + logger.Debug("Tunnel stopped, ignoring remove-peer message") + return + } + jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling data: %v", err) @@ -75,6 +87,12 @@ func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) { func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) { logger.Debug("Received update-peer message: %v", msg.Data) + // Check if tunnel is still running + if !o.tunnelRunning { + logger.Debug("Tunnel stopped, ignoring update-peer message") + return + } + jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling data: %v", err) @@ -199,6 +217,12 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) { func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { logger.Debug("Received peer-handshake message: %v", msg.Data) + // Check if tunnel is still running + if !o.tunnelRunning { + logger.Debug("Tunnel stopped, ignoring peer-handshake message") + return + } + jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling handshake data: %v", err)