Prevent crashing on close before connect

Former-commit-id: ea461e0bfb
This commit is contained in:
Owen
2026-01-23 14:47:54 -08:00
parent 6ae4e2b691
commit ba2631d388
4 changed files with 124 additions and 11 deletions

View File

@@ -28,6 +28,12 @@ type OlmErrorData struct {
func (o *Olm) handleConnect(msg websocket.WSMessage) { func (o *Olm) handleConnect(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data) logger.Debug("Received message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring connect message")
return
}
var wgData WgData var wgData WgData
if o.connected { if o.connected {
@@ -218,6 +224,12 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
func (o *Olm) handleOlmError(msg websocket.WSMessage) { func (o *Olm) handleOlmError(msg websocket.WSMessage) {
logger.Debug("Received olm error message: %v", msg.Data) logger.Debug("Received olm error message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring olm error message")
return
}
var errorData OlmErrorData var errorData OlmErrorData
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
@@ -245,6 +257,12 @@ func (o *Olm) handleOlmError(msg websocket.WSMessage) {
func (o *Olm) handleTerminate(msg websocket.WSMessage) { func (o *Olm) handleTerminate(msg websocket.WSMessage) {
logger.Info("Received terminate message") logger.Info("Received terminate message")
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring terminate message")
return
}
var errorData OlmErrorData var errorData OlmErrorData
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)

View File

@@ -13,6 +13,12 @@ import (
func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) { func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) {
logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data) logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring add-remote-subnets-aliases message")
return
}
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
if err != nil { if err != nil {
logger.Error("Error marshaling data: %v", err) logger.Error("Error marshaling data: %v", err)
@@ -48,6 +54,12 @@ func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) {
func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) { func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) {
logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data) logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring remove-remote-subnets-aliases message")
return
}
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
if err != nil { if err != nil {
logger.Error("Error marshaling data: %v", err) logger.Error("Error marshaling data: %v", err)
@@ -83,6 +95,12 @@ func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) {
func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) { func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) {
logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data) logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring update-remote-subnets-aliases message")
return
}
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
if err != nil { if err != nil {
logger.Error("Error marshaling data: %v", err) logger.Error("Error marshaling data: %v", err)

View File

@@ -8,6 +8,7 @@ import (
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"sync" "sync"
"syscall"
"time" "time"
"github.com/fosrl/newt/bind" "github.com/fosrl/newt/bind"
@@ -66,6 +67,9 @@ type Olm struct {
updateRegister func(newData any) updateRegister func(newData any)
stopPeerSend func() stopPeerSend func()
// WaitGroup to track tunnel lifecycle
tunnelWg sync.WaitGroup
} }
// initTunnelInfo creates the shared UDP socket and holepunch manager. // initTunnelInfo creates the shared UDP socket and holepunch manager.
@@ -389,11 +393,23 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
return nil return nil
} }
// Check if tunnel is still running before starting registration
if !o.tunnelRunning {
logger.Debug("Tunnel is no longer running, skipping registration")
return nil
}
publicKey := o.privateKey.PublicKey() publicKey := o.privateKey.PublicKey()
// delay for 500ms to allow for time for the hp to get processed // delay for 500ms to allow for time for the hp to get processed
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
// Check again after sleep in case tunnel was stopped
if !o.tunnelRunning {
logger.Debug("Tunnel stopped during delay, skipping registration")
return nil
}
if o.stopRegister == nil { if o.stopRegister == nil {
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch) logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
o.stopRegister, o.updateRegister = o.websocket.SendMessageInterval("olm/wg/register", map[string]any{ o.stopRegister, o.updateRegister = o.websocket.SendMessageInterval("olm/wg/register", map[string]any{
@@ -417,6 +433,12 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
}) })
o.websocket.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) { o.websocket.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) {
// Check if tunnel is still running and hole punch manager exists
if !o.tunnelRunning || o.holePunchManager == nil {
logger.Debug("Tunnel stopped or hole punch manager nil, ignoring token update")
return
}
o.holePunchManager.SetToken(token) o.holePunchManager.SetToken(token)
logger.Debug("Got exit nodes for hole punching: %v", exitNodes) logger.Debug("Got exit nodes for hole punching: %v", exitNodes)
@@ -447,6 +469,12 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
}) })
o.websocket.OnAuthError(func(statusCode int, message string) { o.websocket.OnAuthError(func(statusCode int, message string) {
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring auth error")
return
}
logger.Error("Authentication error (status %d): %s. Terminating tunnel.", statusCode, message) logger.Error("Authentication error (status %d): %s. Terminating tunnel.", statusCode, message)
o.apiServer.SetTerminated(true) o.apiServer.SetTerminated(true)
o.apiServer.SetConnectionStatus(false) o.apiServer.SetConnectionStatus(false)
@@ -466,6 +494,10 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
} }
}) })
// Indicate that tunnel is starting
o.tunnelWg.Add(1)
defer o.tunnelWg.Done()
// Connect to the WebSocket server // Connect to the WebSocket server
if err := o.websocket.Connect(); err != nil { if err := o.websocket.Connect(); err != nil {
logger.Error("Failed to connect to server: %v", err) logger.Error("Failed to connect to server: %v", err)
@@ -479,6 +511,13 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
} }
func (o *Olm) Close() { func (o *Olm) Close() {
// Stop registration first to prevent it from trying to use closed websocket
if o.stopRegister != nil {
logger.Debug("Stopping registration interval")
o.stopRegister()
o.stopRegister = nil
}
// send a disconnect message to the cloud to show disconnected // send a disconnect message to the cloud to show disconnected
if o.websocket != nil { if o.websocket != nil {
o.websocket.SendMessage("olm/disconnecting", map[string]any{}) o.websocket.SendMessage("olm/disconnecting", map[string]any{})
@@ -498,11 +537,6 @@ func (o *Olm) Close() {
o.holePunchManager = nil o.holePunchManager = nil
} }
if o.stopRegister != nil {
o.stopRegister()
o.stopRegister = nil
}
// Close() also calls Stop() internally // Close() also calls Stop() internally
if o.peerManager != nil { if o.peerManager != nil {
o.peerManager.Close() o.peerManager.Close()
@@ -533,6 +567,21 @@ func (o *Olm) Close() {
logger.Debug("Closing MiddleDevice") logger.Debug("Closing MiddleDevice")
_ = o.middleDev.Close() _ = o.middleDev.Close()
o.middleDev = nil o.middleDev = nil
} else if o.tdev != nil {
// If middleDev was never created but tdev exists, close it directly
logger.Debug("Closing TUN device directly (no MiddleDevice)")
_ = o.tdev.Close()
o.tdev = nil
} else if o.tunnelConfig.FileDescriptorTun != 0 {
// If we never created a device from the FD, close it explicitly
// This can happen if tunnel is stopped during registration before handleConnect
logger.Debug("Closing unused TUN file descriptor %d", o.tunnelConfig.FileDescriptorTun)
if err := syscall.Close(int(o.tunnelConfig.FileDescriptorTun)); err != nil {
logger.Error("Failed to close TUN file descriptor: %v", err)
} else {
logger.Info("Closed unused TUN file descriptor")
}
o.tunnelConfig.FileDescriptorTun = 0
} }
// Now close WireGuard device - its TUN reader should have exited by now // Now close WireGuard device - its TUN reader should have exited by now
@@ -565,20 +614,24 @@ func (o *Olm) StopTunnel() error {
return nil return nil
} }
// Reset the running state BEFORE cleanup to prevent callbacks from accessing nil pointers
o.connected = false
o.tunnelRunning = false
// Cancel the tunnel context if it exists // Cancel the tunnel context if it exists
if o.tunnelCancel != nil { if o.tunnelCancel != nil {
logger.Debug("Cancelling tunnel context")
o.tunnelCancel() o.tunnelCancel()
// Give it a moment to clean up
time.Sleep(200 * time.Millisecond)
} }
// Wait for the tunnel goroutine to complete
logger.Debug("Waiting for tunnel goroutine to finish")
o.tunnelWg.Wait()
logger.Debug("Tunnel goroutine finished")
// Close() will handle sending disconnect message and closing websocket // Close() will handle sending disconnect message and closing websocket
o.Close() o.Close()
// Reset the connected state
o.connected = false
o.tunnelRunning = false
// Update API server status // Update API server status
o.apiServer.SetConnectionStatus(false) o.apiServer.SetConnectionStatus(false)
o.apiServer.SetRegistered(false) o.apiServer.SetRegistered(false)

View File

@@ -14,6 +14,12 @@ import (
func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) { func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
logger.Debug("Received add-peer message: %v", msg.Data) logger.Debug("Received add-peer message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring add-peer message")
return
}
if o.stopPeerSend != nil { if o.stopPeerSend != nil {
o.stopPeerSend() o.stopPeerSend()
o.stopPeerSend = nil o.stopPeerSend = nil
@@ -44,6 +50,12 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) { func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
logger.Debug("Received remove-peer message: %v", msg.Data) logger.Debug("Received remove-peer message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring remove-peer message")
return
}
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
if err != nil { if err != nil {
logger.Error("Error marshaling data: %v", err) logger.Error("Error marshaling data: %v", err)
@@ -75,6 +87,12 @@ func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) { func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
logger.Debug("Received update-peer message: %v", msg.Data) logger.Debug("Received update-peer message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring update-peer message")
return
}
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
if err != nil { if err != nil {
logger.Error("Error marshaling data: %v", err) logger.Error("Error marshaling data: %v", err)
@@ -199,6 +217,12 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
logger.Debug("Received peer-handshake message: %v", msg.Data) logger.Debug("Received peer-handshake message: %v", msg.Data)
// Check if tunnel is still running
if !o.tunnelRunning {
logger.Debug("Tunnel stopped, ignoring peer-handshake message")
return
}
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
if err != nil { if err != nil {
logger.Error("Error marshaling handshake data: %v", err) logger.Error("Error marshaling handshake data: %v", err)