mirror of
https://github.com/fosrl/olm.git
synced 2026-02-07 21:46:40 +00:00
Prevent crashing on close before connect
This commit is contained in:
@@ -28,6 +28,12 @@ type OlmErrorData struct {
|
||||
func (o *Olm) handleConnect(msg websocket.WSMessage) {
|
||||
logger.Debug("Received message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring connect message")
|
||||
return
|
||||
}
|
||||
|
||||
var wgData WgData
|
||||
|
||||
if o.connected {
|
||||
@@ -218,6 +224,12 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
|
||||
func (o *Olm) handleOlmError(msg websocket.WSMessage) {
|
||||
logger.Debug("Received olm error message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring olm error message")
|
||||
return
|
||||
}
|
||||
|
||||
var errorData OlmErrorData
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
@@ -245,6 +257,12 @@ func (o *Olm) handleOlmError(msg websocket.WSMessage) {
|
||||
func (o *Olm) handleTerminate(msg websocket.WSMessage) {
|
||||
logger.Info("Received terminate message")
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring terminate message")
|
||||
return
|
||||
}
|
||||
|
||||
var errorData OlmErrorData
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
|
||||
18
olm/data.go
18
olm/data.go
@@ -13,6 +13,12 @@ import (
|
||||
func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) {
|
||||
logger.Debug("Received add-remote-subnets-aliases message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring add-remote-subnets-aliases message")
|
||||
return
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
@@ -48,6 +54,12 @@ func (o *Olm) handleWgPeerAddData(msg websocket.WSMessage) {
|
||||
func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) {
|
||||
logger.Debug("Received remove-remote-subnets-aliases message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring remove-remote-subnets-aliases message")
|
||||
return
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
@@ -83,6 +95,12 @@ func (o *Olm) handleWgPeerRemoveData(msg websocket.WSMessage) {
|
||||
func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) {
|
||||
logger.Debug("Received update-remote-subnets-aliases message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring update-remote-subnets-aliases message")
|
||||
return
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
|
||||
75
olm/olm.go
75
olm/olm.go
@@ -8,6 +8,7 @@ import (
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/bind"
|
||||
@@ -66,6 +67,9 @@ type Olm struct {
|
||||
updateRegister func(newData any)
|
||||
|
||||
stopPeerSend func()
|
||||
|
||||
// WaitGroup to track tunnel lifecycle
|
||||
tunnelWg sync.WaitGroup
|
||||
}
|
||||
|
||||
// initTunnelInfo creates the shared UDP socket and holepunch manager.
|
||||
@@ -389,11 +393,23 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if tunnel is still running before starting registration
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel is no longer running, skipping registration")
|
||||
return nil
|
||||
}
|
||||
|
||||
publicKey := o.privateKey.PublicKey()
|
||||
|
||||
// delay for 500ms to allow for time for the hp to get processed
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Check again after sleep in case tunnel was stopped
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped during delay, skipping registration")
|
||||
return nil
|
||||
}
|
||||
|
||||
if o.stopRegister == nil {
|
||||
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
|
||||
o.stopRegister, o.updateRegister = o.websocket.SendMessageInterval("olm/wg/register", map[string]any{
|
||||
@@ -417,6 +433,12 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
})
|
||||
|
||||
o.websocket.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) {
|
||||
// Check if tunnel is still running and hole punch manager exists
|
||||
if !o.tunnelRunning || o.holePunchManager == nil {
|
||||
logger.Debug("Tunnel stopped or hole punch manager nil, ignoring token update")
|
||||
return
|
||||
}
|
||||
|
||||
o.holePunchManager.SetToken(token)
|
||||
|
||||
logger.Debug("Got exit nodes for hole punching: %v", exitNodes)
|
||||
@@ -447,6 +469,12 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
})
|
||||
|
||||
o.websocket.OnAuthError(func(statusCode int, message string) {
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring auth error")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Error("Authentication error (status %d): %s. Terminating tunnel.", statusCode, message)
|
||||
o.apiServer.SetTerminated(true)
|
||||
o.apiServer.SetConnectionStatus(false)
|
||||
@@ -466,6 +494,10 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
}
|
||||
})
|
||||
|
||||
// Indicate that tunnel is starting
|
||||
o.tunnelWg.Add(1)
|
||||
defer o.tunnelWg.Done()
|
||||
|
||||
// Connect to the WebSocket server
|
||||
if err := o.websocket.Connect(); err != nil {
|
||||
logger.Error("Failed to connect to server: %v", err)
|
||||
@@ -479,6 +511,13 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
}
|
||||
|
||||
func (o *Olm) Close() {
|
||||
// Stop registration first to prevent it from trying to use closed websocket
|
||||
if o.stopRegister != nil {
|
||||
logger.Debug("Stopping registration interval")
|
||||
o.stopRegister()
|
||||
o.stopRegister = nil
|
||||
}
|
||||
|
||||
// send a disconnect message to the cloud to show disconnected
|
||||
if o.websocket != nil {
|
||||
o.websocket.SendMessage("olm/disconnecting", map[string]any{})
|
||||
@@ -498,11 +537,6 @@ func (o *Olm) Close() {
|
||||
o.holePunchManager = nil
|
||||
}
|
||||
|
||||
if o.stopRegister != nil {
|
||||
o.stopRegister()
|
||||
o.stopRegister = nil
|
||||
}
|
||||
|
||||
// Close() also calls Stop() internally
|
||||
if o.peerManager != nil {
|
||||
o.peerManager.Close()
|
||||
@@ -533,6 +567,21 @@ func (o *Olm) Close() {
|
||||
logger.Debug("Closing MiddleDevice")
|
||||
_ = o.middleDev.Close()
|
||||
o.middleDev = nil
|
||||
} else if o.tdev != nil {
|
||||
// If middleDev was never created but tdev exists, close it directly
|
||||
logger.Debug("Closing TUN device directly (no MiddleDevice)")
|
||||
_ = o.tdev.Close()
|
||||
o.tdev = nil
|
||||
} else if o.tunnelConfig.FileDescriptorTun != 0 {
|
||||
// If we never created a device from the FD, close it explicitly
|
||||
// This can happen if tunnel is stopped during registration before handleConnect
|
||||
logger.Debug("Closing unused TUN file descriptor %d", o.tunnelConfig.FileDescriptorTun)
|
||||
if err := syscall.Close(int(o.tunnelConfig.FileDescriptorTun)); err != nil {
|
||||
logger.Error("Failed to close TUN file descriptor: %v", err)
|
||||
} else {
|
||||
logger.Info("Closed unused TUN file descriptor")
|
||||
}
|
||||
o.tunnelConfig.FileDescriptorTun = 0
|
||||
}
|
||||
|
||||
// Now close WireGuard device - its TUN reader should have exited by now
|
||||
@@ -565,20 +614,24 @@ func (o *Olm) StopTunnel() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset the running state BEFORE cleanup to prevent callbacks from accessing nil pointers
|
||||
o.connected = false
|
||||
o.tunnelRunning = false
|
||||
|
||||
// Cancel the tunnel context if it exists
|
||||
if o.tunnelCancel != nil {
|
||||
logger.Debug("Cancelling tunnel context")
|
||||
o.tunnelCancel()
|
||||
// Give it a moment to clean up
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Wait for the tunnel goroutine to complete
|
||||
logger.Debug("Waiting for tunnel goroutine to finish")
|
||||
o.tunnelWg.Wait()
|
||||
logger.Debug("Tunnel goroutine finished")
|
||||
|
||||
// Close() will handle sending disconnect message and closing websocket
|
||||
o.Close()
|
||||
|
||||
// Reset the connected state
|
||||
o.connected = false
|
||||
o.tunnelRunning = false
|
||||
|
||||
// Update API server status
|
||||
o.apiServer.SetConnectionStatus(false)
|
||||
o.apiServer.SetRegistered(false)
|
||||
|
||||
24
olm/peer.go
24
olm/peer.go
@@ -14,6 +14,12 @@ import (
|
||||
func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
|
||||
logger.Debug("Received add-peer message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring add-peer message")
|
||||
return
|
||||
}
|
||||
|
||||
if o.stopPeerSend != nil {
|
||||
o.stopPeerSend()
|
||||
o.stopPeerSend = nil
|
||||
@@ -44,6 +50,12 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
|
||||
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
|
||||
logger.Debug("Received remove-peer message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring remove-peer message")
|
||||
return
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
@@ -75,6 +87,12 @@ func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
|
||||
func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
|
||||
logger.Debug("Received update-peer message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring update-peer message")
|
||||
return
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
@@ -199,6 +217,12 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
|
||||
func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
||||
logger.Debug("Received peer-handshake message: %v", msg.Data)
|
||||
|
||||
// Check if tunnel is still running
|
||||
if !o.tunnelRunning {
|
||||
logger.Debug("Tunnel stopped, ignoring peer-handshake message")
|
||||
return
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling handshake data: %v", err)
|
||||
|
||||
Reference in New Issue
Block a user