From d7345c7dbd144d22644efc877f1b89f67d83ad8c Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 18 Nov 2025 18:14:21 -0500 Subject: [PATCH] Split up concerns so parent can call start and stop Former-commit-id: 8f97c43b63a0f6d7a71a27e8aa293a47caea7cd2 --- api/api.go | 144 +++++++++++++------------ main.go | 53 ++++++---- olm/interface.go | 3 +- olm/olm.go | 268 +++++++++++++++++++++++++---------------------- 4 files changed, 246 insertions(+), 222 deletions(-) diff --git a/api/api.go b/api/api.go index a79e20f..a370b82 100644 --- a/api/api.go +++ b/api/api.go @@ -13,10 +13,18 @@ import ( // ConnectionRequest defines the structure for an incoming connection request type ConnectionRequest struct { - ID string `json:"id"` - Secret string `json:"secret"` - Endpoint string `json:"endpoint"` - UserToken string `json:"userToken,omitempty"` + ID string `json:"id"` + Secret string `json:"secret"` + Endpoint string `json:"endpoint"` + UserToken string `json:"userToken,omitempty"` + MTU int `json:"mtu,omitempty"` + DNS string `json:"dns,omitempty"` + InterfaceName string `json:"interfaceName,omitempty"` + Holepunch bool `json:"holepunch,omitempty"` + TlsClientCert string `json:"tlsClientCert,omitempty"` + PingInterval string `json:"pingInterval,omitempty"` + PingTimeout string `json:"pingTimeout,omitempty"` + OrgID string `json:"orgId,omitempty"` } // SwitchOrgRequest defines the structure for switching organizations @@ -47,33 +55,29 @@ type StatusResponse struct { // API represents the HTTP server and its state type API struct { - addr string - socketPath string - listener net.Listener - server *http.Server - connectionChan chan ConnectionRequest - switchOrgChan chan SwitchOrgRequest - shutdownChan chan struct{} - disconnectChan chan struct{} - statusMu sync.RWMutex - peerStatuses map[int]*PeerStatus - connectedAt time.Time - isConnected bool - isRegistered bool - tunnelIP string - version string - orgID string + addr string + socketPath string + listener net.Listener + server *http.Server + onConnect func(ConnectionRequest) error + onSwitchOrg func(SwitchOrgRequest) error + onDisconnect func() error + onExit func() error + statusMu sync.RWMutex + peerStatuses map[int]*PeerStatus + connectedAt time.Time + isConnected bool + isRegistered bool + tunnelIP string + version string + orgID string } // NewAPI creates a new HTTP server that listens on a TCP address func NewAPI(addr string) *API { s := &API{ - addr: addr, - connectionChan: make(chan ConnectionRequest, 1), - switchOrgChan: make(chan SwitchOrgRequest, 1), - shutdownChan: make(chan struct{}, 1), - disconnectChan: make(chan struct{}, 1), - peerStatuses: make(map[int]*PeerStatus), + addr: addr, + peerStatuses: make(map[int]*PeerStatus), } return s @@ -82,17 +86,26 @@ func NewAPI(addr string) *API { // NewAPISocket creates a new HTTP server that listens on a Unix socket or Windows named pipe func NewAPISocket(socketPath string) *API { s := &API{ - socketPath: socketPath, - connectionChan: make(chan ConnectionRequest, 1), - switchOrgChan: make(chan SwitchOrgRequest, 1), - shutdownChan: make(chan struct{}, 1), - disconnectChan: make(chan struct{}, 1), - peerStatuses: make(map[int]*PeerStatus), + socketPath: socketPath, + peerStatuses: make(map[int]*PeerStatus), } return s } +// SetHandlers sets the callback functions for handling API requests +func (s *API) SetHandlers( + onConnect func(ConnectionRequest) error, + onSwitchOrg func(SwitchOrgRequest) error, + onDisconnect func() error, + onExit func() error, +) { + s.onConnect = onConnect + s.onSwitchOrg = onSwitchOrg + s.onDisconnect = onDisconnect + s.onExit = onExit +} + // Start starts the HTTP server func (s *API) Start() error { mux := http.NewServeMux() @@ -149,26 +162,6 @@ func (s *API) Stop() error { return nil } -// GetConnectionChannel returns the channel for receiving connection requests -func (s *API) GetConnectionChannel() <-chan ConnectionRequest { - return s.connectionChan -} - -// GetSwitchOrgChannel returns the channel for receiving org switch requests -func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest { - return s.switchOrgChan -} - -// GetShutdownChannel returns the channel for receiving shutdown requests -func (s *API) GetShutdownChannel() <-chan struct{} { - return s.shutdownChan -} - -// GetDisconnectChannel returns the channel for receiving disconnect requests -func (s *API) GetDisconnectChannel() <-chan struct{} { - return s.disconnectChan -} - // UpdatePeerStatus updates the status of a peer including endpoint and relay info func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) { s.statusMu.Lock() @@ -277,8 +270,13 @@ func (s *API) handleConnect(w http.ResponseWriter, r *http.Request) { return } - // Send the request to the main goroutine - s.connectionChan <- req + // Call the connect handler if set + if s.onConnect != nil { + if err := s.onConnect(req); err != nil { + http.Error(w, fmt.Sprintf("Connection failed: %v", err), http.StatusInternalServerError) + return + } + } // Return a success response w.Header().Set("Content-Type", "application/json") @@ -320,12 +318,12 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) { logger.Info("Received exit request via API") - // Send shutdown signal - select { - case s.shutdownChan <- struct{}{}: - // Signal sent successfully - default: - // Channel already has a signal, don't block + // Call the exit handler if set + if s.onExit != nil { + if err := s.onExit(); err != nil { + http.Error(w, fmt.Sprintf("Exit failed: %v", err), http.StatusInternalServerError) + return + } } // Return a success response @@ -358,14 +356,12 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) { logger.Info("Received org switch request to orgId: %s", req.OrgID) - // Send the request to the main goroutine - select { - case s.switchOrgChan <- req: - // Signal sent successfully - default: - // Channel already has a pending request - http.Error(w, "Org switch already in progress", http.StatusConflict) - return + // Call the switch org handler if set + if s.onSwitchOrg != nil { + if err := s.onSwitchOrg(req); err != nil { + http.Error(w, fmt.Sprintf("Org switch failed: %v", err), http.StatusInternalServerError) + return + } } // Return a success response @@ -394,12 +390,12 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) { logger.Info("Received disconnect request via API") - // Send disconnect signal - select { - case s.disconnectChan <- struct{}{}: - // Signal sent successfully - default: - // Channel already has a signal, don't block + // Call the disconnect handler if set + if s.onDisconnect != nil { + if err := s.onDisconnect(); err != nil { + http.Error(w, fmt.Sprintf("Disconnect failed: %v", err), http.StatusInternalServerError) + return + } } // Return a success response diff --git a/main.go b/main.go index b07ca5a..4656636 100644 --- a/main.go +++ b/main.go @@ -205,26 +205,41 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { } // Create a new olm.Config struct and copy values from the main config - olmConfig := olm.Config{ - Endpoint: config.Endpoint, - ID: config.ID, - Secret: config.Secret, - UserToken: config.UserToken, - MTU: config.MTU, - DNS: config.DNS, - InterfaceName: config.InterfaceName, - LogLevel: config.LogLevel, - EnableAPI: config.EnableAPI, - HTTPAddr: config.HTTPAddr, - SocketPath: config.SocketPath, - Holepunch: config.Holepunch, - TlsClientCert: config.TlsClientCert, - PingIntervalDuration: config.PingIntervalDuration, - PingTimeoutDuration: config.PingTimeoutDuration, - Version: config.Version, - OrgID: config.OrgID, - // DoNotCreateNewClient: config.DoNotCreateNewClient, + olmConfig := olm.GlobalConfig{ + LogLevel: config.LogLevel, + EnableAPI: config.EnableAPI, + HTTPAddr: config.HTTPAddr, + SocketPath: config.SocketPath, + Version: config.Version, } olm.Init(ctx, olmConfig) + + if config.ID != "" && config.Secret != "" && config.Endpoint != "" { + tunnelConfig := olm.TunnelConfig{ + Endpoint: config.Endpoint, + ID: config.ID, + Secret: config.Secret, + UserToken: config.UserToken, + MTU: config.MTU, + DNS: config.DNS, + InterfaceName: config.InterfaceName, + Holepunch: config.Holepunch, + TlsClientCert: config.TlsClientCert, + PingIntervalDuration: config.PingIntervalDuration, + PingTimeoutDuration: config.PingTimeoutDuration, + OrgID: config.OrgID, + } + go olm.StartTunnel(tunnelConfig) + } else { + logger.Info("Incomplete tunnel configuration, not starting tunnel") + } + + // Wait for context cancellation (from signals or API shutdown) + <-ctx.Done() + logger.Info("Shutdown signal received, cleaning up...") + + // Clean up resources + olm.Close() + logger.Info("Shutdown complete") } diff --git a/olm/interface.go b/olm/interface.go index 9e76dc1..0e09d58 100644 --- a/olm/interface.go +++ b/olm/interface.go @@ -15,7 +15,7 @@ import ( ) // ConfigureInterface configures a network interface with an IP address and brings it up -func ConfigureInterface(interfaceName string, wgData WgData) error { +func ConfigureInterface(interfaceName string, wgData WgData, mtu int) error { logger.Info("The tunnel IP is: %s", wgData.TunnelIP) // Parse the IP address and network @@ -32,6 +32,7 @@ func ConfigureInterface(interfaceName string, wgData WgData) error { // network.SetTunnelRemoteAddress() // what does this do? network.SetIPv4Settings([]string{destinationAddress}, []string{mask}) + network.SetMTU(mtu) apiServer.SetTunnelIP(destinationAddress) if interfaceName == "" { diff --git a/olm/olm.go b/olm/olm.go index 18ed302..9b7ab66 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -3,6 +3,7 @@ package olm import ( "context" "encoding/json" + "fmt" "net" "runtime" "time" @@ -20,7 +21,21 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -type Config struct { +type GlobalConfig struct { + // Logging + LogLevel string + + // HTTP server + EnableAPI bool + HTTPAddr string + SocketPath string + Version string + + // Source tracking (not in JSON) + sources map[string]string +} + +type TunnelConfig struct { // Connection settings Endpoint string ID string @@ -32,14 +47,6 @@ type Config struct { DNS string InterfaceName string - // Logging - LogLevel string - - // HTTP server - EnableAPI bool - HTTPAddr string - SocketPath string - // Advanced Holepunch bool TlsClientCert string @@ -48,11 +55,7 @@ type Config struct { PingIntervalDuration time.Duration PingTimeoutDuration time.Duration - // Source tracking (not in JSON) - sources map[string]string - - Version string - OrgID string + OrgID string // DoNotCreateNewClient bool FileDescriptorTun uint32 @@ -74,21 +77,21 @@ var ( sharedBind *bind.SharedBind holePunchManager *holepunch.Manager peerMonitor *peermonitor.PeerMonitor + globalConfig GlobalConfig + globalCtx context.Context stopRegister func() stopPing chan struct{} ) -func Init(ctx context.Context, config Config) { +func Init(ctx context.Context, config GlobalConfig) { + globalConfig = config + globalCtx = ctx + // Create a cancellable context for internal shutdown control ctx, cancel := context.WithCancel(ctx) defer cancel() logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) - network.SetMTU(config.MTU) - - if config.Holepunch { - logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.") - } if config.HTTPAddr != "" { apiServer = api.NewAPI(config.HTTPAddr) @@ -97,35 +100,15 @@ func Init(ctx context.Context, config Config) { } apiServer.SetVersion(config.Version) - apiServer.SetOrgID(config.OrgID) if err := apiServer.Start(); err != nil { logger.Fatal("Failed to start HTTP server: %v", err) } - // Listen for shutdown requests from the API - go func() { - <-apiServer.GetShutdownChannel() - logger.Info("Shutdown requested via API") - // Cancel the context to trigger graceful shutdown - cancel() - }() - - var ( - id = config.ID - secret = config.Secret - endpoint = config.Endpoint - userToken = config.UserToken - ) - - // Main event loop that handles connect, disconnect, and reconnect - for { - select { - case <-ctx.Done(): - logger.Info("Context cancelled while waiting for credentials") - goto shutdown - - case req := <-apiServer.GetConnectionChannel(): + // Set up API handlers + apiServer.SetHandlers( + // onConnect + func(req api.ConnectionRequest) error { logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) // Stop any existing tunnel before starting a new one @@ -134,67 +117,120 @@ func Init(ctx context.Context, config Config) { StopTunnel() } - // Set the connection parameters - id = req.ID - secret = req.Secret - endpoint = req.Endpoint - userToken := req.UserToken + tunnelConfig := TunnelConfig{ + Endpoint: req.Endpoint, + ID: req.ID, + Secret: req.Secret, + UserToken: req.UserToken, + MTU: req.MTU, + DNS: req.DNS, + InterfaceName: req.InterfaceName, + Holepunch: req.Holepunch, + TlsClientCert: req.TlsClientCert, + OrgID: req.OrgID, + } + + var err error + // Parse ping interval + if req.PingInterval != "" { + tunnelConfig.PingIntervalDuration, err = time.ParseDuration(req.PingInterval) + if err != nil { + logger.Warn("Invalid PING_INTERVAL value: %s, using default 3 seconds", req.PingInterval) + tunnelConfig.PingIntervalDuration = 3 * time.Second + } + } else { + tunnelConfig.PingIntervalDuration = 3 * time.Second + } + // Parse ping timeout + if req.PingTimeout != "" { + tunnelConfig.PingTimeoutDuration, err = time.ParseDuration(req.PingTimeout) + if err != nil { + logger.Warn("Invalid PING_TIMEOUT value: %s, using default 5 seconds", req.PingTimeout) + tunnelConfig.PingTimeoutDuration = 5 * time.Second + } + } else { + tunnelConfig.PingTimeoutDuration = 5 * time.Second + } + if req.MTU == 0 { + tunnelConfig.MTU = 1420 + } + if req.DNS == "" { + tunnelConfig.DNS = "9.9.9.9" + } + if req.InterfaceName == "" { + tunnelConfig.InterfaceName = "olm" + } // Start the tunnel process with the new credentials - if id != "" && secret != "" && endpoint != "" { + if tunnelConfig.ID != "" && tunnelConfig.Secret != "" && tunnelConfig.Endpoint != "" { logger.Info("Starting tunnel with new credentials") - tunnelRunning = true - go StartTunnel(ctx, config, id, secret, userToken, endpoint) + go StartTunnel(tunnelConfig) } - case <-apiServer.GetDisconnectChannel(): - logger.Info("Received disconnect request via API") + return nil + }, + // onSwitchOrg + func(req api.SwitchOrgRequest) error { + logger.Info("Processing org switch request to orgId: %s", req.OrgID) + + // Ensure we have an active olmClient + if olmClient == nil { + return fmt.Errorf("no active connection to switch organizations") + } + + // Update the orgID in the API server + apiServer.SetOrgID(req.OrgID) + + // Mark as not connected to trigger re-registration + connected = false + + Close() + + // Clear peer statuses in API + apiServer.SetRegistered(false) + apiServer.SetTunnelIP("") + + // Trigger re-registration with new orgId + logger.Info("Re-registering with new orgId: %s", req.OrgID) + publicKey := privateKey.PublicKey() + stopRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{ + "publicKey": publicKey.String(), + "relay": true, // Default to relay mode for org switch + "olmVersion": globalConfig.Version, + "orgId": req.OrgID, + }, 1*time.Second) + + return nil + }, + // onDisconnect + func() error { + logger.Info("Processing disconnect request via API") StopTunnel() - // Clear credentials so we wait for new connect call - id = "" - secret = "" - endpoint = "" - userToken = "" - - default: - // If we have credentials and no tunnel is running, start it - if id != "" && secret != "" && endpoint != "" && !tunnelRunning { - logger.Info("Starting tunnel process with initial credentials") - tunnelRunning = true - go StartTunnel(ctx, config, id, secret, userToken, endpoint) - } else if id == "" || secret == "" || endpoint == "" { - // If we don't have credentials, check if API is enabled - if !config.EnableAPI { - missing := []string{} - if id == "" { - missing = append(missing, "id") - } - if secret == "" { - missing = append(missing, "secret") - } - if endpoint == "" { - missing = append(missing, "endpoint") - } - // exit the application because there is no way to provide the missing parameters - logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing) - goto shutdown - } - } - - // Sleep briefly to prevent tight loop - time.Sleep(100 * time.Millisecond) - } - } - -shutdown: - Close() - apiServer.Stop() - logger.Info("Olm service shutting down") + return nil + }, + // onExit + func() error { + logger.Info("Processing shutdown request via API") + cancel() + return nil + }, + ) } -func StartTunnel(ctx context.Context, config Config, id string, secret string, userToken string, endpoint string) { +func StartTunnel(config TunnelConfig) { + if tunnelRunning { + logger.Info("Tunnel already running") + return + } + + tunnelRunning = true // Also set it here in case it is called externally + + if config.Holepunch { + logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.") + } + // Create a cancellable context for this tunnel process - tunnelCtx, cancel := context.WithCancel(ctx) + tunnelCtx, cancel := context.WithCancel(globalCtx) tunnelCancel = cancel defer func() { tunnelCancel = nil @@ -205,8 +241,14 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u var ( interfaceName = config.InterfaceName + id = config.ID + secret = config.Secret + endpoint = config.Endpoint + userToken = config.UserToken ) + apiServer.SetOrgID(config.OrgID) + // Create a new olm client using the provided credentials olm, err := websocket.NewClient( id, // Use provided ID @@ -431,7 +473,7 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u logger.Error("Failed to bring up WireGuard device: %v", err) } - if err = ConfigureInterface(interfaceName, wgData); err != nil { + if err = ConfigureInterface(interfaceName, wgData, config.MTU); err != nil { logger.Error("Failed to configure interface: %v", err) } @@ -753,7 +795,7 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ "publicKey": publicKey.String(), "relay": !config.Holepunch, - "olmVersion": config.Version, + "olmVersion": globalConfig.Version, "orgId": config.OrgID, // "doNotCreateNewClient": config.DoNotCreateNewClient, }, 1*time.Second) @@ -777,36 +819,6 @@ func StartTunnel(ctx context.Context, config Config, id string, secret string, u } defer olm.Close() - // Listen for org switch requests from the API - go func() { - for req := range apiServer.GetSwitchOrgChannel() { - logger.Info("Processing org switch request to orgId: %s", req.OrgID) - - // Update the config with the new orgId - config.OrgID = req.OrgID - - // Mark as not connected to trigger re-registration - connected = false - - Close() - - // Clear peer statuses in API - apiServer.SetRegistered(false) - apiServer.SetTunnelIP("") - apiServer.SetOrgID(config.OrgID) - - // Trigger re-registration with new orgId - logger.Info("Re-registering with new orgId: %s", config.OrgID) - publicKey := privateKey.PublicKey() - stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ - "publicKey": publicKey.String(), - "relay": !config.Holepunch, - "olmVersion": config.Version, - "orgId": config.OrgID, - }, 1*time.Second) - } - }() - // Wait for context cancellation <-tunnelCtx.Done() logger.Info("Tunnel process context cancelled, cleaning up")