Allow connecting and disconnecting

Former-commit-id: 596c4aa0da
This commit is contained in:
Owen
2025-11-07 14:07:44 -08:00
parent 963d8abad5
commit ce3c585514
2 changed files with 277 additions and 221 deletions

View File

@@ -16,6 +16,7 @@ type ConnectionRequest struct {
ID string `json:"id"`
Secret string `json:"secret"`
Endpoint string `json:"endpoint"`
UserToken string `json:"userToken,omitempty"`
}
// SwitchOrgRequest defines the structure for switching organizations
@@ -53,6 +54,7 @@ type API struct {
connectionChan chan ConnectionRequest
switchOrgChan chan SwitchOrgRequest
shutdownChan chan struct{}
disconnectChan chan struct{}
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
@@ -70,6 +72,7 @@ func NewAPI(addr string) *API {
connectionChan: make(chan ConnectionRequest, 1),
switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1),
disconnectChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus),
}
@@ -83,6 +86,7 @@ func NewAPISocket(socketPath string) *API {
connectionChan: make(chan ConnectionRequest, 1),
switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1),
disconnectChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus),
}
@@ -95,6 +99,7 @@ func (s *API) Start() error {
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
mux.HandleFunc("/disconnect", s.handleDisconnect)
mux.HandleFunc("/exit", s.handleExit)
s.server = &http.Server{
@@ -159,6 +164,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} {
return s.shutdownChan
}
// GetDisconnectChannel returns the channel for receiving disconnect requests
func (s *API) GetDisconnectChannel() <-chan struct{} {
return s.disconnectChan
}
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock()
@@ -356,3 +366,28 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
"status": "org switch request accepted",
})
}
// handleDisconnect handles the /disconnect endpoint
func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
logger.Info("Received disconnect request via API")
// Send disconnect signal
select {
case s.disconnectChan <- struct{}{}:
// Signal sent successfully
default:
// Channel already has a signal, don't block
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{
"status": "disconnect initiated",
})
}

View File

@@ -3,7 +3,6 @@ package olm
import (
"context"
"encoding/json"
"fmt"
"net"
"os"
"runtime"
@@ -39,10 +38,6 @@ type Config struct {
HTTPAddr string
SocketPath string
// Ping settings
PingInterval string
PingTimeout string
// Advanced
Holepunch bool
TlsClientCert string
@@ -58,22 +53,7 @@ type Config struct {
OrgID string
}
func Run(ctx context.Context, config Config) {
// Create a cancellable context for internal shutdown control
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// Extract commonly used values from config for convenience
var (
endpoint = config.Endpoint
id = config.ID
secret = config.Secret
mtu = config.MTU
logLevel = config.LogLevel
interfaceName = config.InterfaceName
pingInterval = config.PingIntervalDuration
pingTimeout = config.PingTimeoutDuration
doHolepunch = config.Holepunch
var (
privateKey wgtypes.Key
connected bool
dev *device.Device
@@ -81,28 +61,26 @@ func Run(ctx context.Context, config Config) {
holePunchData HolePunchData
uapiListener net.Listener
tdev tun.Device
)
apiServer *api.API
olmClient *websocket.Client
tunnelCancel context.CancelFunc
)
stopHolepunch = make(chan struct{})
stopPing = make(chan struct{})
func Run(ctx context.Context, config Config) {
// Create a cancellable context for internal shutdown control
ctx, cancel := context.WithCancel(ctx)
defer cancel()
loggerLevel := parseLogLevel(logLevel)
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
logger.GetLogger().SetLevel(parseLogLevel(config.LogLevel))
if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil {
logger.Debug("Failed to check for updates: %v", err)
}
// Log startup information
logger.Debug("Olm service starting...")
logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret)
if doHolepunch {
if config.Holepunch {
logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.")
}
var apiServer *api.API
if config.EnableAPI {
if config.HTTPAddr != "" {
apiServer = api.NewAPI(config.HTTPAddr)
} else if config.SocketPath != "" {
@@ -111,6 +89,7 @@ func Run(ctx context.Context, config Config) {
apiServer.SetVersion(config.Version)
apiServer.SetOrgID(config.OrgID)
if err := apiServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err)
}
@@ -122,42 +101,56 @@ func Run(ctx context.Context, config Config) {
// Cancel the context to trigger graceful shutdown
cancel()
}()
}
// // Use a goroutine to handle connection requests
// go func() {
// for req := range apiServer.GetConnectionChannel() {
// logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// // Set the connection parameters
// id = req.ID
// secret = req.Secret
// endpoint = req.Endpoint
// }
// }()
// }
// Create a new olm
olm, err := websocket.NewClient(
"olm",
id, // CLI arg takes precedence
secret, // CLI arg takes precedence
endpoint,
pingInterval,
pingTimeout,
var (
id = config.ID
secret = config.Secret
endpoint = config.Endpoint
)
if err != nil {
logger.Fatal("Failed to create olm: %v", err)
}
// wait until we have a client id and secret and endpoint
waitCount := 0
for id == "" || secret == "" || endpoint == "" {
// Main event loop that handles connect, disconnect, and reconnect
for {
select {
case <-ctx.Done():
logger.Info("Context cancelled while waiting for credentials")
return
goto shutdown
case req := <-apiServer.GetConnectionChannel():
logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint)
// Stop any existing tunnel before starting a new one
if olmClient != nil {
logger.Info("Stopping existing tunnel before starting new connection")
StopTunnel()
}
// Set the connection parameters
id = req.ID
secret = req.Secret
endpoint = req.Endpoint
// Start the tunnel process with the new credentials
if id != "" && secret != "" && endpoint != "" {
logger.Info("Starting tunnel with new credentials")
go TunnelProcess(ctx, config, id, secret, endpoint)
}
case <-apiServer.GetDisconnectChannel():
logger.Info("Received disconnect request via API")
StopTunnel()
// Clear credentials so we wait for new connect call
id = ""
secret = ""
endpoint = ""
default:
// If we have credentials and no tunnel is running, start it
if id != "" && secret != "" && endpoint != "" && olmClient == nil {
logger.Info("Starting tunnel process with initial credentials")
go TunnelProcess(ctx, config, id, secret, endpoint)
} else if id == "" || secret == "" || endpoint == "" {
// If we don't have credentials, check if API is enabled
if !config.EnableAPI {
missing := []string{}
if id == "" {
missing = append(missing, "id")
@@ -168,23 +161,67 @@ func Run(ctx context.Context, config Config) {
if endpoint == "" {
missing = append(missing, "endpoint")
}
waitCount++
if waitCount%10 == 1 { // Log every 10 seconds instead of every second
logger.Debug("Waiting for missing parameters: %v (waiting %d seconds)", missing, waitCount)
}
time.Sleep(1 * time.Second)
// exit the application because there is no way to provide the missing parameters
logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing)
goto shutdown
}
}
// Sleep briefly to prevent tight loop
time.Sleep(100 * time.Millisecond)
}
}
shutdown:
Stop()
apiServer.Stop()
logger.Info("Olm service shutting down")
}
func TunnelProcess(ctx context.Context, config Config, id string, secret string, endpoint string) {
// Create a cancellable context for this tunnel process
tunnelCtx, cancel := context.WithCancel(ctx)
tunnelCancel = cancel
defer func() {
tunnelCancel = nil
}()
// Recreate channels for this tunnel session
stopHolepunch = make(chan struct{})
stopPing = make(chan struct{})
var (
interfaceName = config.InterfaceName
loggerLevel = parseLogLevel(config.LogLevel)
)
// Create a new olm client using the provided credentials
olm, err := websocket.NewClient(
"olm",
id, // Use provided ID
secret, // Use provided secret
endpoint, // Use provided endpoint
config.PingIntervalDuration,
config.PingTimeoutDuration,
)
if err != nil {
logger.Error("Failed to create olm: %v", err)
return
}
// Store the client reference globally
olmClient = olm
privateKey, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Fatal("Failed to generate private key: %v", err)
logger.Error("Failed to generate private key: %v", err)
return
}
sourcePort, err := FindAvailableUDPPort(49152, 65535)
if err != nil {
fmt.Printf("Error finding available port: %v\n", err)
os.Exit(1)
logger.Error("Error finding available port: %v", err)
return
}
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
@@ -289,12 +326,12 @@ func Run(ctx context.Context, config Config) {
if err != nil {
return nil, err
}
return tun.CreateTUN(interfaceName, mtu)
return tun.CreateTUN(interfaceName, config.MTU)
}
if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" {
return createTUNFromFD(tunFdStr, mtu)
return createTUNFromFD(tunFdStr, config.MTU)
}
return tun.CreateTUN(interfaceName, mtu)
return tun.CreateTUN(interfaceName, config.MTU)
}()
if err != nil {
@@ -347,13 +384,10 @@ func Run(ctx context.Context, config Config) {
if err = ConfigureInterface(interfaceName, wgData); err != nil {
logger.Error("Failed to configure interface: %v", err)
}
if apiServer != nil {
apiServer.SetTunnelIP(wgData.TunnelIP)
}
peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
if apiServer != nil {
// Find the site config to get endpoint information
var endpoint string
var isRelay bool
@@ -362,12 +396,11 @@ func Run(ctx context.Context, config Config) {
endpoint = site.Endpoint
// TODO: We'll need to track relay status separately
// For now, assume not using relay unless we get relay data
isRelay = !doHolepunch
isRelay = !config.Holepunch
break
}
}
apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
}
if connected {
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
} else {
@@ -377,14 +410,12 @@ func Run(ctx context.Context, config Config) {
fixKey(privateKey.String()),
olm,
dev,
doHolepunch,
config.Holepunch,
)
for i := range wgData.Sites {
site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice
if apiServer != nil {
apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false)
}
// Format the endpoint before configuring the peer.
site.Endpoint = formatEndpoint(site.Endpoint)
@@ -407,9 +438,7 @@ func Run(ctx context.Context, config Config) {
peerMonitor.Start()
if apiServer != nil {
apiServer.SetRegistered(true)
}
connected = true
@@ -637,9 +666,7 @@ func Run(ctx context.Context, config Config) {
}
// Update HTTP server to mark this peer as using relay
if apiServer != nil {
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
}
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
})
@@ -670,9 +697,7 @@ func Run(ctx context.Context, config Config) {
olm.OnConnect(func() error {
logger.Info("Websocket Connected")
if apiServer != nil {
apiServer.SetConnectionStatus(true)
}
if connected {
logger.Debug("Already connected, skipping registration")
@@ -681,11 +706,11 @@ func Run(ctx context.Context, config Config) {
publicKey := privateKey.PublicKey()
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !doHolepunch,
"relay": !config.Holepunch,
"olmVersion": config.Version,
"orgId": config.OrgID,
}, 1*time.Second)
@@ -700,8 +725,14 @@ func Run(ctx context.Context, config Config) {
olmToken = token
})
// Connect to the WebSocket server
if err := olm.Connect(); err != nil {
logger.Error("Failed to connect to server: %v", err)
return
}
defer olm.Close()
// Listen for org switch requests from the API
if apiServer != nil {
go func() {
for req := range apiServer.GetSwitchOrgChannel() {
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
@@ -712,77 +743,32 @@ func Run(ctx context.Context, config Config) {
// Mark as not connected to trigger re-registration
connected = false
// Stop registration if running
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
// Stop hole punching
select {
case <-stopHolepunch:
// Already closed
default:
close(stopHolepunch)
}
stopHolepunch = make(chan struct{})
// Stop peer monitor
if peerMonitor != nil {
peerMonitor.Stop()
peerMonitor = nil
}
// Close the WireGuard device
if dev != nil {
logger.Info("Closing existing WireGuard device for org switch")
dev.Close()
dev = nil
}
// Close UAPI listener
if uapiListener != nil {
uapiListener.Close()
uapiListener = nil
}
// Close TUN device
if tdev != nil {
tdev.Close()
tdev = nil
}
Stop()
// Clear peer statuses in API
if apiServer != nil {
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
apiServer.SetOrgID(config.OrgID)
}
stopHolepunch = make(chan struct{})
// Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", config.OrgID)
publicKey := privateKey.PublicKey()
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !doHolepunch,
"relay": !config.Holepunch,
"olmVersion": config.Version,
"orgId": config.OrgID,
}, 1*time.Second)
}
}()
}
// Connect to the WebSocket server
if err := olm.Connect(); err != nil {
logger.Fatal("Failed to connect to server: %v", err)
}
defer olm.Close()
select {
case <-ctx.Done():
logger.Info("Context cancelled")
}
// Wait for context cancellation
<-tunnelCtx.Done()
logger.Info("Tunnel process context cancelled, cleaning up")
}
func Stop() {
select {
case <-stopHolepunch:
// Channel already closed, do nothing
@@ -790,11 +776,6 @@ func Run(ctx context.Context, config Config) {
close(stopHolepunch)
}
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
select {
case <-stopPing:
// Channel already closed
@@ -802,20 +783,60 @@ func Run(ctx context.Context, config Config) {
close(stopPing)
}
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
if peerMonitor != nil {
peerMonitor.Stop()
peerMonitor = nil
}
if uapiListener != nil {
uapiListener.Close()
uapiListener = nil
}
if dev != nil {
dev.Close()
dev = nil
}
if apiServer != nil {
apiServer.Stop()
// Close TUN device
if tdev != nil {
tdev.Close()
tdev = nil
}
logger.Info("Olm service stopped")
}
// StopTunnel stops just the tunnel process and websocket connection
// without shutting down the entire application
func StopTunnel() {
logger.Info("Stopping tunnel process")
// Cancel the tunnel context if it exists
if tunnelCancel != nil {
tunnelCancel()
// Give it a moment to clean up
time.Sleep(200 * time.Millisecond)
}
// Close the websocket connection
if olmClient != nil {
olmClient.Close()
olmClient = nil
}
Stop()
// Reset the connected state
connected = false
// Update API server status
apiServer.SetConnectionStatus(false)
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
logger.Info("Tunnel process stopped")
}