Allow connecting and disconnecting

This commit is contained in:
Owen
2025-11-07 14:07:44 -08:00
parent da1e4911bd
commit 596c4aa0da
2 changed files with 277 additions and 221 deletions

View File

@@ -16,6 +16,7 @@ type ConnectionRequest struct {
ID string `json:"id"` ID string `json:"id"`
Secret string `json:"secret"` Secret string `json:"secret"`
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
UserToken string `json:"userToken,omitempty"`
} }
// SwitchOrgRequest defines the structure for switching organizations // SwitchOrgRequest defines the structure for switching organizations
@@ -53,6 +54,7 @@ type API struct {
connectionChan chan ConnectionRequest connectionChan chan ConnectionRequest
switchOrgChan chan SwitchOrgRequest switchOrgChan chan SwitchOrgRequest
shutdownChan chan struct{} shutdownChan chan struct{}
disconnectChan chan struct{}
statusMu sync.RWMutex statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus peerStatuses map[int]*PeerStatus
connectedAt time.Time connectedAt time.Time
@@ -70,6 +72,7 @@ func NewAPI(addr string) *API {
connectionChan: make(chan ConnectionRequest, 1), connectionChan: make(chan ConnectionRequest, 1),
switchOrgChan: make(chan SwitchOrgRequest, 1), switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1), shutdownChan: make(chan struct{}, 1),
disconnectChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus), peerStatuses: make(map[int]*PeerStatus),
} }
@@ -83,6 +86,7 @@ func NewAPISocket(socketPath string) *API {
connectionChan: make(chan ConnectionRequest, 1), connectionChan: make(chan ConnectionRequest, 1),
switchOrgChan: make(chan SwitchOrgRequest, 1), switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1), shutdownChan: make(chan struct{}, 1),
disconnectChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus), peerStatuses: make(map[int]*PeerStatus),
} }
@@ -95,6 +99,7 @@ func (s *API) Start() error {
mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus) mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/switch-org", s.handleSwitchOrg) mux.HandleFunc("/switch-org", s.handleSwitchOrg)
mux.HandleFunc("/disconnect", s.handleDisconnect)
mux.HandleFunc("/exit", s.handleExit) mux.HandleFunc("/exit", s.handleExit)
s.server = &http.Server{ s.server = &http.Server{
@@ -159,6 +164,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} {
return s.shutdownChan return s.shutdownChan
} }
// GetDisconnectChannel returns the channel for receiving disconnect requests
func (s *API) GetDisconnectChannel() <-chan struct{} {
return s.disconnectChan
}
// UpdatePeerStatus updates the status of a peer including endpoint and relay info // UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) { func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock() s.statusMu.Lock()
@@ -356,3 +366,28 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
"status": "org switch request accepted", "status": "org switch request accepted",
}) })
} }
// handleDisconnect handles the /disconnect endpoint
func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
logger.Info("Received disconnect request via API")
// Send disconnect signal
select {
case s.disconnectChan <- struct{}{}:
// Signal sent successfully
default:
// Channel already has a signal, don't block
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": "disconnect initiated",
})
}

View File

@@ -3,7 +3,6 @@ package olm
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"net" "net"
"os" "os"
"runtime" "runtime"
@@ -39,10 +38,6 @@ type Config struct {
HTTPAddr string HTTPAddr string
SocketPath string SocketPath string
// Ping settings
PingInterval string
PingTimeout string
// Advanced // Advanced
Holepunch bool Holepunch bool
TlsClientCert string TlsClientCert string
@@ -58,22 +53,7 @@ type Config struct {
OrgID string OrgID string
} }
func Run(ctx context.Context, config Config) { var (
// Create a cancellable context for internal shutdown control
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// Extract commonly used values from config for convenience
var (
endpoint = config.Endpoint
id = config.ID
secret = config.Secret
mtu = config.MTU
logLevel = config.LogLevel
interfaceName = config.InterfaceName
pingInterval = config.PingIntervalDuration
pingTimeout = config.PingTimeoutDuration
doHolepunch = config.Holepunch
privateKey wgtypes.Key privateKey wgtypes.Key
connected bool connected bool
dev *device.Device dev *device.Device
@@ -81,28 +61,26 @@ func Run(ctx context.Context, config Config) {
holePunchData HolePunchData holePunchData HolePunchData
uapiListener net.Listener uapiListener net.Listener
tdev tun.Device tdev tun.Device
) apiServer *api.API
olmClient *websocket.Client
tunnelCancel context.CancelFunc
)
stopHolepunch = make(chan struct{}) func Run(ctx context.Context, config Config) {
stopPing = make(chan struct{}) // Create a cancellable context for internal shutdown control
ctx, cancel := context.WithCancel(ctx)
defer cancel()
loggerLevel := parseLogLevel(logLevel) logger.GetLogger().SetLevel(parseLogLevel(config.LogLevel))
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil { if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil {
logger.Debug("Failed to check for updates: %v", err) logger.Debug("Failed to check for updates: %v", err)
} }
// Log startup information if config.Holepunch {
logger.Debug("Olm service starting...")
logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret)
if doHolepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.") logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
} }
var apiServer *api.API
if config.EnableAPI {
if config.HTTPAddr != "" { if config.HTTPAddr != "" {
apiServer = api.NewAPI(config.HTTPAddr) apiServer = api.NewAPI(config.HTTPAddr)
} else if config.SocketPath != "" { } else if config.SocketPath != "" {
@@ -111,6 +89,7 @@ func Run(ctx context.Context, config Config) {
apiServer.SetVersion(config.Version) apiServer.SetVersion(config.Version)
apiServer.SetOrgID(config.OrgID) apiServer.SetOrgID(config.OrgID)
if err := apiServer.Start(); err != nil { if err := apiServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err) logger.Fatal("Failed to start HTTP server: %v", err)
} }
@@ -122,42 +101,56 @@ func Run(ctx context.Context, config Config) {
// Cancel the context to trigger graceful shutdown // Cancel the context to trigger graceful shutdown
cancel() cancel()
}() }()
}
// // Use a goroutine to handle connection requests var (
// go func() { id = config.ID
// for req := range apiServer.GetConnectionChannel() { secret = config.Secret
// logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) endpoint = config.Endpoint
// // Set the connection parameters
// id = req.ID
// secret = req.Secret
// endpoint = req.Endpoint
// }
// }()
// }
// Create a new olm
olm, err := websocket.NewClient(
"olm",
id, // CLI arg takes precedence
secret, // CLI arg takes precedence
endpoint,
pingInterval,
pingTimeout,
) )
if err != nil {
logger.Fatal("Failed to create olm: %v", err)
}
// wait until we have a client id and secret and endpoint // Main event loop that handles connect, disconnect, and reconnect
waitCount := 0 for {
for id == "" || secret == "" || endpoint == "" {
select { select {
case <-ctx.Done(): case <-ctx.Done():
logger.Info("Context cancelled while waiting for credentials") logger.Info("Context cancelled while waiting for credentials")
return goto shutdown
case req := <-apiServer.GetConnectionChannel():
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// Stop any existing tunnel before starting a new one
if olmClient != nil {
logger.Info("Stopping existing tunnel before starting new connection")
StopTunnel()
}
// Set the connection parameters
id = req.ID
secret = req.Secret
endpoint = req.Endpoint
// Start the tunnel process with the new credentials
if id != "" && secret != "" && endpoint != "" {
logger.Info("Starting tunnel with new credentials")
go TunnelProcess(ctx, config, id, secret, endpoint)
}
case <-apiServer.GetDisconnectChannel():
logger.Info("Received disconnect request via API")
StopTunnel()
// Clear credentials so we wait for new connect call
id = ""
secret = ""
endpoint = ""
default: default:
// If we have credentials and no tunnel is running, start it
if id != "" && secret != "" && endpoint != "" && olmClient == nil {
logger.Info("Starting tunnel process with initial credentials")
go TunnelProcess(ctx, config, id, secret, endpoint)
} else if id == "" || secret == "" || endpoint == "" {
// If we don't have credentials, check if API is enabled
if !config.EnableAPI {
missing := []string{} missing := []string{}
if id == "" { if id == "" {
missing = append(missing, "id") missing = append(missing, "id")
@@ -168,23 +161,67 @@ func Run(ctx context.Context, config Config) {
if endpoint == "" { if endpoint == "" {
missing = append(missing, "endpoint") missing = append(missing, "endpoint")
} }
waitCount++ // exit the application because there is no way to provide the missing parameters
if waitCount%10 == 1 { // Log every 10 seconds instead of every second logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing)
logger.Debug("Waiting for missing parameters: %v (waiting %d seconds)", missing, waitCount) goto shutdown
}
time.Sleep(1 * time.Second)
} }
} }
// Sleep briefly to prevent tight loop
time.Sleep(100 * time.Millisecond)
}
}
shutdown:
Stop()
apiServer.Stop()
logger.Info("Olm service shutting down")
}
func TunnelProcess(ctx context.Context, config Config, id string, secret string, endpoint string) {
// Create a cancellable context for this tunnel process
tunnelCtx, cancel := context.WithCancel(ctx)
tunnelCancel = cancel
defer func() {
tunnelCancel = nil
}()
// Recreate channels for this tunnel session
stopHolepunch = make(chan struct{})
stopPing = make(chan struct{})
var (
interfaceName = config.InterfaceName
loggerLevel = parseLogLevel(config.LogLevel)
)
// Create a new olm client using the provided credentials
olm, err := websocket.NewClient(
"olm",
id, // Use provided ID
secret, // Use provided secret
endpoint, // Use provided endpoint
config.PingIntervalDuration,
config.PingTimeoutDuration,
)
if err != nil {
logger.Error("Failed to create olm: %v", err)
return
}
// Store the client reference globally
olmClient = olm
privateKey, err = wgtypes.GeneratePrivateKey() privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
logger.Fatal("Failed to generate private key: %v", err) logger.Error("Failed to generate private key: %v", err)
return
} }
sourcePort, err := FindAvailableUDPPort(49152, 65535) sourcePort, err := FindAvailableUDPPort(49152, 65535)
if err != nil { if err != nil {
fmt.Printf("Error finding available port: %v\n", err) logger.Error("Error finding available port: %v", err)
os.Exit(1) return
} }
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) { olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
@@ -289,12 +326,12 @@ func Run(ctx context.Context, config Config) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return tun.CreateTUN(interfaceName, mtu) return tun.CreateTUN(interfaceName, config.MTU)
} }
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" { if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
return createTUNFromFD(tunFdStr, mtu) return createTUNFromFD(tunFdStr, config.MTU)
} }
return tun.CreateTUN(interfaceName, mtu) return tun.CreateTUN(interfaceName, config.MTU)
}() }()
if err != nil { if err != nil {
@@ -347,13 +384,10 @@ func Run(ctx context.Context, config Config) {
if err = ConfigureInterface(interfaceName, wgData); err != nil { if err = ConfigureInterface(interfaceName, wgData); err != nil {
logger.Error("Failed to configure interface: %v", err) logger.Error("Failed to configure interface: %v", err)
} }
if apiServer != nil {
apiServer.SetTunnelIP(wgData.TunnelIP) apiServer.SetTunnelIP(wgData.TunnelIP)
}
peerMonitor = peermonitor.NewPeerMonitor( peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) { func(siteID int, connected bool, rtt time.Duration) {
if apiServer != nil {
// Find the site config to get endpoint information // Find the site config to get endpoint information
var endpoint string var endpoint string
var isRelay bool var isRelay bool
@@ -362,12 +396,11 @@ func Run(ctx context.Context, config Config) {
endpoint = site.Endpoint endpoint = site.Endpoint
// TODO: We'll need to track relay status separately // TODO: We'll need to track relay status separately
// For now, assume not using relay unless we get relay data // For now, assume not using relay unless we get relay data
isRelay = !doHolepunch isRelay = !config.Holepunch
break break
} }
} }
apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay) apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
}
if connected { if connected {
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt) logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
} else { } else {
@@ -377,14 +410,12 @@ func Run(ctx context.Context, config Config) {
fixKey(privateKey.String()), fixKey(privateKey.String()),
olm, olm,
dev, dev,
doHolepunch, config.Holepunch,
) )
for i := range wgData.Sites { for i := range wgData.Sites {
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
if apiServer != nil {
apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false) apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
}
// Format the endpoint before configuring the peer. // Format the endpoint before configuring the peer.
site.Endpoint = formatEndpoint(site.Endpoint) site.Endpoint = formatEndpoint(site.Endpoint)
@@ -407,9 +438,7 @@ func Run(ctx context.Context, config Config) {
peerMonitor.Start() peerMonitor.Start()
if apiServer != nil {
apiServer.SetRegistered(true) apiServer.SetRegistered(true)
}
connected = true connected = true
@@ -637,9 +666,7 @@ func Run(ctx context.Context, config Config) {
} }
// Update HTTP server to mark this peer as using relay // Update HTTP server to mark this peer as using relay
if apiServer != nil {
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true) apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
}
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay) peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
}) })
@@ -670,9 +697,7 @@ func Run(ctx context.Context, config Config) {
olm.OnConnect(func() error { olm.OnConnect(func() error {
logger.Info("Websocket Connected") logger.Info("Websocket Connected")
if apiServer != nil {
apiServer.SetConnectionStatus(true) apiServer.SetConnectionStatus(true)
}
if connected { if connected {
logger.Debug("Already connected, skipping registration") logger.Debug("Already connected, skipping registration")
@@ -681,11 +706,11 @@ func Run(ctx context.Context, config Config) {
publicKey := privateKey.PublicKey() publicKey := privateKey.PublicKey()
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch) logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"relay": !doHolepunch, "relay": !config.Holepunch,
"olmVersion": config.Version, "olmVersion": config.Version,
"orgId": config.OrgID, "orgId": config.OrgID,
}, 1*time.Second) }, 1*time.Second)
@@ -700,8 +725,14 @@ func Run(ctx context.Context, config Config) {
olmToken = token olmToken = token
}) })
// Connect to the WebSocket server
if err := olm.Connect(); err != nil {
logger.Error("Failed to connect to server: %v", err)
return
}
defer olm.Close()
// Listen for org switch requests from the API // Listen for org switch requests from the API
if apiServer != nil {
go func() { go func() {
for req := range apiServer.GetSwitchOrgChannel() { for req := range apiServer.GetSwitchOrgChannel() {
logger.Info("Processing org switch request to orgId: %s", req.OrgID) logger.Info("Processing org switch request to orgId: %s", req.OrgID)
@@ -712,77 +743,32 @@ func Run(ctx context.Context, config Config) {
// Mark as not connected to trigger re-registration // Mark as not connected to trigger re-registration
connected = false connected = false
// Stop registration if running Stop()
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
// Stop hole punching
select {
case <-stopHolepunch:
// Already closed
default:
close(stopHolepunch)
}
stopHolepunch = make(chan struct{})
// Stop peer monitor
if peerMonitor != nil {
peerMonitor.Stop()
peerMonitor = nil
}
// Close the WireGuard device
if dev != nil {
logger.Info("Closing existing WireGuard device for org switch")
dev.Close()
dev = nil
}
// Close UAPI listener
if uapiListener != nil {
uapiListener.Close()
uapiListener = nil
}
// Close TUN device
if tdev != nil {
tdev.Close()
tdev = nil
}
// Clear peer statuses in API // Clear peer statuses in API
if apiServer != nil {
apiServer.SetRegistered(false) apiServer.SetRegistered(false)
apiServer.SetTunnelIP("") apiServer.SetTunnelIP("")
apiServer.SetOrgID(config.OrgID) apiServer.SetOrgID(config.OrgID)
}
stopHolepunch = make(chan struct{})
// Trigger re-registration with new orgId // Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", config.OrgID) logger.Info("Re-registering with new orgId: %s", config.OrgID)
publicKey := privateKey.PublicKey() publicKey := privateKey.PublicKey()
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"relay": !doHolepunch, "relay": !config.Holepunch,
"olmVersion": config.Version, "olmVersion": config.Version,
"orgId": config.OrgID, "orgId": config.OrgID,
}, 1*time.Second) }, 1*time.Second)
} }
}() }()
}
// Connect to the WebSocket server // Wait for context cancellation
if err := olm.Connect(); err != nil { <-tunnelCtx.Done()
logger.Fatal("Failed to connect to server: %v", err) logger.Info("Tunnel process context cancelled, cleaning up")
} }
defer olm.Close()
select {
case <-ctx.Done():
logger.Info("Context cancelled")
}
func Stop() {
select { select {
case <-stopHolepunch: case <-stopHolepunch:
// Channel already closed, do nothing // Channel already closed, do nothing
@@ -790,11 +776,6 @@ func Run(ctx context.Context, config Config) {
close(stopHolepunch) close(stopHolepunch)
} }
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
select { select {
case <-stopPing: case <-stopPing:
// Channel already closed // Channel already closed
@@ -802,20 +783,60 @@ func Run(ctx context.Context, config Config) {
close(stopPing) close(stopPing)
} }
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
if peerMonitor != nil { if peerMonitor != nil {
peerMonitor.Stop() peerMonitor.Stop()
peerMonitor = nil
} }
if uapiListener != nil { if uapiListener != nil {
uapiListener.Close() uapiListener.Close()
uapiListener = nil
} }
if dev != nil { if dev != nil {
dev.Close() dev.Close()
dev = nil
} }
// Close TUN device
if apiServer != nil { if tdev != nil {
apiServer.Stop() tdev.Close()
tdev = nil
} }
logger.Info("Olm service stopped") logger.Info("Olm service stopped")
} }
// StopTunnel stops just the tunnel process and websocket connection
// without shutting down the entire application
func StopTunnel() {
logger.Info("Stopping tunnel process")
// Cancel the tunnel context if it exists
if tunnelCancel != nil {
tunnelCancel()
// Give it a moment to clean up
time.Sleep(200 * time.Millisecond)
}
// Close the websocket connection
if olmClient != nil {
olmClient.Close()
olmClient = nil
}
Stop()
// Reset the connected state
connected = false
// Update API server status
apiServer.SetConnectionStatus(false)
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
logger.Info("Tunnel process stopped")
}