Allow connecting and disconnecting

This commit is contained in:
Owen
2025-11-07 14:07:44 -08:00
parent da1e4911bd
commit 596c4aa0da
2 changed files with 277 additions and 221 deletions

View File

@@ -13,9 +13,10 @@ import (
// ConnectionRequest defines the structure for an incoming connection request // ConnectionRequest defines the structure for an incoming connection request
type ConnectionRequest struct { type ConnectionRequest struct {
ID string `json:"id"` ID string `json:"id"`
Secret string `json:"secret"` Secret string `json:"secret"`
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
UserToken string `json:"userToken,omitempty"`
} }
// SwitchOrgRequest defines the structure for switching organizations // SwitchOrgRequest defines the structure for switching organizations
@@ -53,6 +54,7 @@ type API struct {
connectionChan chan ConnectionRequest connectionChan chan ConnectionRequest
switchOrgChan chan SwitchOrgRequest switchOrgChan chan SwitchOrgRequest
shutdownChan chan struct{} shutdownChan chan struct{}
disconnectChan chan struct{}
statusMu sync.RWMutex statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus peerStatuses map[int]*PeerStatus
connectedAt time.Time connectedAt time.Time
@@ -70,6 +72,7 @@ func NewAPI(addr string) *API {
connectionChan: make(chan ConnectionRequest, 1), connectionChan: make(chan ConnectionRequest, 1),
switchOrgChan: make(chan SwitchOrgRequest, 1), switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1), shutdownChan: make(chan struct{}, 1),
disconnectChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus), peerStatuses: make(map[int]*PeerStatus),
} }
@@ -83,6 +86,7 @@ func NewAPISocket(socketPath string) *API {
connectionChan: make(chan ConnectionRequest, 1), connectionChan: make(chan ConnectionRequest, 1),
switchOrgChan: make(chan SwitchOrgRequest, 1), switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1), shutdownChan: make(chan struct{}, 1),
disconnectChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus), peerStatuses: make(map[int]*PeerStatus),
} }
@@ -95,6 +99,7 @@ func (s *API) Start() error {
mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus) mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/switch-org", s.handleSwitchOrg) mux.HandleFunc("/switch-org", s.handleSwitchOrg)
mux.HandleFunc("/disconnect", s.handleDisconnect)
mux.HandleFunc("/exit", s.handleExit) mux.HandleFunc("/exit", s.handleExit)
s.server = &http.Server{ s.server = &http.Server{
@@ -159,6 +164,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} {
return s.shutdownChan return s.shutdownChan
} }
// GetDisconnectChannel returns the channel for receiving disconnect requests
func (s *API) GetDisconnectChannel() <-chan struct{} {
return s.disconnectChan
}
// UpdatePeerStatus updates the status of a peer including endpoint and relay info // UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) { func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock() s.statusMu.Lock()
@@ -356,3 +366,28 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
"status": "org switch request accepted", "status": "org switch request accepted",
}) })
} }
// handleDisconnect handles the /disconnect endpoint
func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
logger.Info("Received disconnect request via API")
// Send disconnect signal
select {
case s.disconnectChan <- struct{}{}:
// Signal sent successfully
default:
// Channel already has a signal, don't block
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": "disconnect initiated",
})
}

View File

@@ -3,7 +3,6 @@ package olm
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"net" "net"
"os" "os"
"runtime" "runtime"
@@ -39,10 +38,6 @@ type Config struct {
HTTPAddr string HTTPAddr string
SocketPath string SocketPath string
// Ping settings
PingInterval string
PingTimeout string
// Advanced // Advanced
Holepunch bool Holepunch bool
TlsClientCert string TlsClientCert string
@@ -58,133 +53,175 @@ type Config struct {
OrgID string OrgID string
} }
var (
privateKey wgtypes.Key
connected bool
dev *device.Device
wgData WgData
holePunchData HolePunchData
uapiListener net.Listener
tdev tun.Device
apiServer *api.API
olmClient *websocket.Client
tunnelCancel context.CancelFunc
)
func Run(ctx context.Context, config Config) { func Run(ctx context.Context, config Config) {
// Create a cancellable context for internal shutdown control // Create a cancellable context for internal shutdown control
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
// Extract commonly used values from config for convenience logger.GetLogger().SetLevel(parseLogLevel(config.LogLevel))
var (
endpoint = config.Endpoint
id = config.ID
secret = config.Secret
mtu = config.MTU
logLevel = config.LogLevel
interfaceName = config.InterfaceName
pingInterval = config.PingIntervalDuration
pingTimeout = config.PingTimeoutDuration
doHolepunch = config.Holepunch
privateKey wgtypes.Key
connected bool
dev *device.Device
wgData WgData
holePunchData HolePunchData
uapiListener net.Listener
tdev tun.Device
)
stopHolepunch = make(chan struct{})
stopPing = make(chan struct{})
loggerLevel := parseLogLevel(logLevel)
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil { if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil {
logger.Debug("Failed to check for updates: %v", err) logger.Debug("Failed to check for updates: %v", err)
} }
// Log startup information if config.Holepunch {
logger.Debug("Olm service starting...")
logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret)
if doHolepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.") logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
} }
var apiServer *api.API if config.HTTPAddr != "" {
if config.EnableAPI { apiServer = api.NewAPI(config.HTTPAddr)
if config.HTTPAddr != "" { } else if config.SocketPath != "" {
apiServer = api.NewAPI(config.HTTPAddr) apiServer = api.NewAPISocket(config.SocketPath)
} else if config.SocketPath != "" {
apiServer = api.NewAPISocket(config.SocketPath)
}
apiServer.SetVersion(config.Version)
apiServer.SetOrgID(config.OrgID)
if err := apiServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err)
}
// Listen for shutdown requests from the API
go func() {
<-apiServer.GetShutdownChannel()
logger.Info("Shutdown requested via API")
// Cancel the context to trigger graceful shutdown
cancel()
}()
} }
// // Use a goroutine to handle connection requests apiServer.SetVersion(config.Version)
// go func() { apiServer.SetOrgID(config.OrgID)
// for req := range apiServer.GetConnectionChannel() {
// logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// // Set the connection parameters if err := apiServer.Start(); err != nil {
// id = req.ID logger.Fatal("Failed to start HTTP server: %v", err)
// secret = req.Secret }
// endpoint = req.Endpoint
// }
// }()
// }
// Create a new olm // Listen for shutdown requests from the API
olm, err := websocket.NewClient( go func() {
"olm", <-apiServer.GetShutdownChannel()
id, // CLI arg takes precedence logger.Info("Shutdown requested via API")
secret, // CLI arg takes precedence // Cancel the context to trigger graceful shutdown
endpoint, cancel()
pingInterval, }()
pingTimeout,
var (
id = config.ID
secret = config.Secret
endpoint = config.Endpoint
) )
if err != nil {
logger.Fatal("Failed to create olm: %v", err)
}
// wait until we have a client id and secret and endpoint // Main event loop that handles connect, disconnect, and reconnect
waitCount := 0 for {
for id == "" || secret == "" || endpoint == "" {
select { select {
case <-ctx.Done(): case <-ctx.Done():
logger.Info("Context cancelled while waiting for credentials") logger.Info("Context cancelled while waiting for credentials")
return goto shutdown
case req := <-apiServer.GetConnectionChannel():
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// Stop any existing tunnel before starting a new one
if olmClient != nil {
logger.Info("Stopping existing tunnel before starting new connection")
StopTunnel()
}
// Set the connection parameters
id = req.ID
secret = req.Secret
endpoint = req.Endpoint
// Start the tunnel process with the new credentials
if id != "" && secret != "" && endpoint != "" {
logger.Info("Starting tunnel with new credentials")
go TunnelProcess(ctx, config, id, secret, endpoint)
}
case <-apiServer.GetDisconnectChannel():
logger.Info("Received disconnect request via API")
StopTunnel()
// Clear credentials so we wait for new connect call
id = ""
secret = ""
endpoint = ""
default: default:
missing := []string{} // If we have credentials and no tunnel is running, start it
if id == "" { if id != "" && secret != "" && endpoint != "" && olmClient == nil {
missing = append(missing, "id") logger.Info("Starting tunnel process with initial credentials")
go TunnelProcess(ctx, config, id, secret, endpoint)
} else if id == "" || secret == "" || endpoint == "" {
// If we don't have credentials, check if API is enabled
if !config.EnableAPI {
missing := []string{}
if id == "" {
missing = append(missing, "id")
}
if secret == "" {
missing = append(missing, "secret")
}
if endpoint == "" {
missing = append(missing, "endpoint")
}
// exit the application because there is no way to provide the missing parameters
logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing)
goto shutdown
}
} }
if secret == "" {
missing = append(missing, "secret") // Sleep briefly to prevent tight loop
} time.Sleep(100 * time.Millisecond)
if endpoint == "" {
missing = append(missing, "endpoint")
}
waitCount++
if waitCount%10 == 1 { // Log every 10 seconds instead of every second
logger.Debug("Waiting for missing parameters: %v (waiting %d seconds)", missing, waitCount)
}
time.Sleep(1 * time.Second)
} }
} }
shutdown:
Stop()
apiServer.Stop()
logger.Info("Olm service shutting down")
}
func TunnelProcess(ctx context.Context, config Config, id string, secret string, endpoint string) {
// Create a cancellable context for this tunnel process
tunnelCtx, cancel := context.WithCancel(ctx)
tunnelCancel = cancel
defer func() {
tunnelCancel = nil
}()
// Recreate channels for this tunnel session
stopHolepunch = make(chan struct{})
stopPing = make(chan struct{})
var (
interfaceName = config.InterfaceName
loggerLevel = parseLogLevel(config.LogLevel)
)
// Create a new olm client using the provided credentials
olm, err := websocket.NewClient(
"olm",
id, // Use provided ID
secret, // Use provided secret
endpoint, // Use provided endpoint
config.PingIntervalDuration,
config.PingTimeoutDuration,
)
if err != nil {
logger.Error("Failed to create olm: %v", err)
return
}
// Store the client reference globally
olmClient = olm
privateKey, err = wgtypes.GeneratePrivateKey() privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
logger.Fatal("Failed to generate private key: %v", err) logger.Error("Failed to generate private key: %v", err)
return
} }
sourcePort, err := FindAvailableUDPPort(49152, 65535) sourcePort, err := FindAvailableUDPPort(49152, 65535)
if err != nil { if err != nil {
fmt.Printf("Error finding available port: %v\n", err) logger.Error("Error finding available port: %v", err)
os.Exit(1) return
} }
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) { olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
@@ -289,12 +326,12 @@ func Run(ctx context.Context, config Config) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return tun.CreateTUN(interfaceName, mtu) return tun.CreateTUN(interfaceName, config.MTU)
} }
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" { if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
return createTUNFromFD(tunFdStr, mtu) return createTUNFromFD(tunFdStr, config.MTU)
} }
return tun.CreateTUN(interfaceName, mtu) return tun.CreateTUN(interfaceName, config.MTU)
}() }()
if err != nil { if err != nil {
@@ -347,27 +384,23 @@ func Run(ctx context.Context, config Config) {
if err = ConfigureInterface(interfaceName, wgData); err != nil { if err = ConfigureInterface(interfaceName, wgData); err != nil {
logger.Error("Failed to configure interface: %v", err) logger.Error("Failed to configure interface: %v", err)
} }
if apiServer != nil { apiServer.SetTunnelIP(wgData.TunnelIP)
apiServer.SetTunnelIP(wgData.TunnelIP)
}
peerMonitor = peermonitor.NewPeerMonitor( peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) { func(siteID int, connected bool, rtt time.Duration) {
if apiServer != nil { // Find the site config to get endpoint information
// Find the site config to get endpoint information var endpoint string
var endpoint string var isRelay bool
var isRelay bool for _, site := range wgData.Sites {
for _, site := range wgData.Sites { if site.SiteId == siteID {
if site.SiteId == siteID { endpoint = site.Endpoint
endpoint = site.Endpoint // TODO: We'll need to track relay status separately
// TODO: We'll need to track relay status separately // For now, assume not using relay unless we get relay data
// For now, assume not using relay unless we get relay data isRelay = !config.Holepunch
isRelay = !doHolepunch break
break
}
} }
apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
} }
apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
if connected { if connected {
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt) logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
} else { } else {
@@ -377,14 +410,12 @@ func Run(ctx context.Context, config Config) {
fixKey(privateKey.String()), fixKey(privateKey.String()),
olm, olm,
dev, dev,
doHolepunch, config.Holepunch,
) )
for i := range wgData.Sites { for i := range wgData.Sites {
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
if apiServer != nil { apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
}
// Format the endpoint before configuring the peer. // Format the endpoint before configuring the peer.
site.Endpoint = formatEndpoint(site.Endpoint) site.Endpoint = formatEndpoint(site.Endpoint)
@@ -407,9 +438,7 @@ func Run(ctx context.Context, config Config) {
peerMonitor.Start() peerMonitor.Start()
if apiServer != nil { apiServer.SetRegistered(true)
apiServer.SetRegistered(true)
}
connected = true connected = true
@@ -637,9 +666,7 @@ func Run(ctx context.Context, config Config) {
} }
// Update HTTP server to mark this peer as using relay // Update HTTP server to mark this peer as using relay
if apiServer != nil { apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
}
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay) peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
}) })
@@ -670,9 +697,7 @@ func Run(ctx context.Context, config Config) {
olm.OnConnect(func() error { olm.OnConnect(func() error {
logger.Info("Websocket Connected") logger.Info("Websocket Connected")
if apiServer != nil { apiServer.SetConnectionStatus(true)
apiServer.SetConnectionStatus(true)
}
if connected { if connected {
logger.Debug("Already connected, skipping registration") logger.Debug("Already connected, skipping registration")
@@ -681,11 +706,11 @@ func Run(ctx context.Context, config Config) {
publicKey := privateKey.PublicKey() publicKey := privateKey.PublicKey()
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch) logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"relay": !doHolepunch, "relay": !config.Holepunch,
"olmVersion": config.Version, "olmVersion": config.Version,
"orgId": config.OrgID, "orgId": config.OrgID,
}, 1*time.Second) }, 1*time.Second)
@@ -700,89 +725,50 @@ func Run(ctx context.Context, config Config) {
olmToken = token olmToken = token
}) })
// Listen for org switch requests from the API
if apiServer != nil {
go func() {
for req := range apiServer.GetSwitchOrgChannel() {
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
// Update the config with the new orgId
config.OrgID = req.OrgID
// Mark as not connected to trigger re-registration
connected = false
// Stop registration if running
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
// Stop hole punching
select {
case <-stopHolepunch:
// Already closed
default:
close(stopHolepunch)
}
stopHolepunch = make(chan struct{})
// Stop peer monitor
if peerMonitor != nil {
peerMonitor.Stop()
peerMonitor = nil
}
// Close the WireGuard device
if dev != nil {
logger.Info("Closing existing WireGuard device for org switch")
dev.Close()
dev = nil
}
// Close UAPI listener
if uapiListener != nil {
uapiListener.Close()
uapiListener = nil
}
// Close TUN device
if tdev != nil {
tdev.Close()
tdev = nil
}
// Clear peer statuses in API
if apiServer != nil {
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
apiServer.SetOrgID(config.OrgID)
}
// Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", config.OrgID)
publicKey := privateKey.PublicKey()
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !doHolepunch,
"olmVersion": config.Version,
"orgId": config.OrgID,
}, 1*time.Second)
}
}()
}
// Connect to the WebSocket server // Connect to the WebSocket server
if err := olm.Connect(); err != nil { if err := olm.Connect(); err != nil {
logger.Fatal("Failed to connect to server: %v", err) logger.Error("Failed to connect to server: %v", err)
return
} }
defer olm.Close() defer olm.Close()
select { // Listen for org switch requests from the API
case <-ctx.Done(): go func() {
logger.Info("Context cancelled") for req := range apiServer.GetSwitchOrgChannel() {
} logger.Info("Processing org switch request to orgId: %s", req.OrgID)
// Update the config with the new orgId
config.OrgID = req.OrgID
// Mark as not connected to trigger re-registration
connected = false
Stop()
// Clear peer statuses in API
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
apiServer.SetOrgID(config.OrgID)
stopHolepunch = make(chan struct{})
// Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", config.OrgID)
publicKey := privateKey.PublicKey()
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !config.Holepunch,
"olmVersion": config.Version,
"orgId": config.OrgID,
}, 1*time.Second)
}
}()
// Wait for context cancellation
<-tunnelCtx.Done()
logger.Info("Tunnel process context cancelled, cleaning up")
}
func Stop() {
select { select {
case <-stopHolepunch: case <-stopHolepunch:
// Channel already closed, do nothing // Channel already closed, do nothing
@@ -790,11 +776,6 @@ func Run(ctx context.Context, config Config) {
close(stopHolepunch) close(stopHolepunch)
} }
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
select { select {
case <-stopPing: case <-stopPing:
// Channel already closed // Channel already closed
@@ -802,20 +783,60 @@ func Run(ctx context.Context, config Config) {
close(stopPing) close(stopPing)
} }
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
if peerMonitor != nil { if peerMonitor != nil {
peerMonitor.Stop() peerMonitor.Stop()
peerMonitor = nil
} }
if uapiListener != nil { if uapiListener != nil {
uapiListener.Close() uapiListener.Close()
uapiListener = nil
} }
if dev != nil { if dev != nil {
dev.Close() dev.Close()
dev = nil
} }
// Close TUN device
if apiServer != nil { if tdev != nil {
apiServer.Stop() tdev.Close()
tdev = nil
} }
logger.Info("Olm service stopped") logger.Info("Olm service stopped")
} }
// StopTunnel stops just the tunnel process and websocket connection
// without shutting down the entire application
func StopTunnel() {
logger.Info("Stopping tunnel process")
// Cancel the tunnel context if it exists
if tunnelCancel != nil {
tunnelCancel()
// Give it a moment to clean up
time.Sleep(200 * time.Millisecond)
}
// Close the websocket connection
if olmClient != nil {
olmClient.Close()
olmClient = nil
}
Stop()
// Reset the connected state
connected = false
// Update API server status
apiServer.SetConnectionStatus(false)
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
logger.Info("Tunnel process stopped")
}