diff --git a/olm/olm.go b/olm/olm.go index 9b7ab66..4c067e8 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -273,41 +273,37 @@ func StartTunnel(config TunnelConfig) { } // Create shared UDP socket for both holepunch and WireGuard - if sharedBind == nil { - sourcePort, err := util.FindAvailableUDPPort(49152, 65535) - if err != nil { - logger.Error("Error finding available port: %v", err) - return - } - - localAddr := &net.UDPAddr{ - Port: int(sourcePort), - IP: net.IPv4zero, - } - - udpConn, err := net.ListenUDP("udp", localAddr) - if err != nil { - logger.Error("Failed to create shared UDP socket: %v", err) - return - } - - sharedBind, err = bind.New(udpConn) - if err != nil { - logger.Error("Failed to create shared bind: %v", err) - udpConn.Close() - return - } - - // Add a reference for the hole punch senders (creator already has one reference for WireGuard) - sharedBind.AddRef() - - logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount()) + sourcePort, err := util.FindAvailableUDPPort(49152, 65535) + if err != nil { + logger.Error("Error finding available port: %v", err) + return } + localAddr := &net.UDPAddr{ + Port: int(sourcePort), + IP: net.IPv4zero, + } + + udpConn, err := net.ListenUDP("udp", localAddr) + if err != nil { + logger.Error("Failed to create shared UDP socket: %v", err) + return + } + + sharedBind, err = bind.New(udpConn) + if err != nil { + logger.Error("Failed to create shared bind: %v", err) + udpConn.Close() + return + } + + // Add a reference for the hole punch senders (creator already has one reference for WireGuard) + sharedBind.AddRef() + + logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount()) + // Create the holepunch manager - if holePunchManager == nil { - holePunchManager = holepunch.NewManager(sharedBind, id, "olm") - } + holePunchManager = holepunch.NewManager(sharedBind, id, "olm") olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) { logger.Debug("Received message: %v", msg.Data) @@ -828,6 +824,7 @@ func Close() { // Stop hole punch manager if holePunchManager != nil { holePunchManager.Stop() + holePunchManager = nil } if stopPing != nil { @@ -853,10 +850,12 @@ func Close() { uapiListener.Close() uapiListener = nil } + if dev != nil { dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference dev = nil } + // Close TUN device if tdev != nil { tdev.Close()