diff --git a/api/api.go b/api/api.go index 969513d..83fd6f3 100644 --- a/api/api.go +++ b/api/api.go @@ -13,9 +13,10 @@ import ( // ConnectionRequest defines the structure for an incoming connection request type ConnectionRequest struct { - ID string `json:"id"` - Secret string `json:"secret"` - Endpoint string `json:"endpoint"` + ID string `json:"id"` + Secret string `json:"secret"` + Endpoint string `json:"endpoint"` + UserToken string `json:"userToken,omitempty"` } // SwitchOrgRequest defines the structure for switching organizations @@ -53,6 +54,7 @@ type API struct { connectionChan chan ConnectionRequest switchOrgChan chan SwitchOrgRequest shutdownChan chan struct{} + disconnectChan chan struct{} statusMu sync.RWMutex peerStatuses map[int]*PeerStatus connectedAt time.Time @@ -70,6 +72,7 @@ func NewAPI(addr string) *API { connectionChan: make(chan ConnectionRequest, 1), switchOrgChan: make(chan SwitchOrgRequest, 1), shutdownChan: make(chan struct{}, 1), + disconnectChan: make(chan struct{}, 1), peerStatuses: make(map[int]*PeerStatus), } @@ -83,6 +86,7 @@ func NewAPISocket(socketPath string) *API { connectionChan: make(chan ConnectionRequest, 1), switchOrgChan: make(chan SwitchOrgRequest, 1), shutdownChan: make(chan struct{}, 1), + disconnectChan: make(chan struct{}, 1), peerStatuses: make(map[int]*PeerStatus), } @@ -95,6 +99,7 @@ func (s *API) Start() error { mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/status", s.handleStatus) mux.HandleFunc("/switch-org", s.handleSwitchOrg) + mux.HandleFunc("/disconnect", s.handleDisconnect) mux.HandleFunc("/exit", s.handleExit) s.server = &http.Server{ @@ -159,6 +164,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} { return s.shutdownChan } +// GetDisconnectChannel returns the channel for receiving disconnect requests +func (s *API) GetDisconnectChannel() <-chan struct{} { + return s.disconnectChan +} + // UpdatePeerStatus updates the status of a peer including endpoint and relay info func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) { s.statusMu.Lock() @@ -356,3 +366,28 @@ func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) { "status": "org switch request accepted", }) } + +// handleDisconnect handles the /disconnect endpoint +func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + logger.Info("Received disconnect request via API") + + // Send disconnect signal + select { + case s.disconnectChan <- struct{}{}: + // Signal sent successfully + default: + // Channel already has a signal, don't block + } + + // Return a success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{ + "status": "disconnect initiated", + }) +} diff --git a/olm/olm.go b/olm/olm.go index bb3433a..a28f896 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -3,7 +3,6 @@ package olm import ( "context" "encoding/json" - "fmt" "net" "os" "runtime" @@ -39,10 +38,6 @@ type Config struct { HTTPAddr string SocketPath string - // Ping settings - PingInterval string - PingTimeout string - // Advanced Holepunch bool TlsClientCert string @@ -58,133 +53,175 @@ type Config struct { OrgID string } +var ( + privateKey wgtypes.Key + connected bool + dev *device.Device + wgData WgData + holePunchData HolePunchData + uapiListener net.Listener + tdev tun.Device + apiServer *api.API + olmClient *websocket.Client + tunnelCancel context.CancelFunc +) + func Run(ctx context.Context, config Config) { // Create a cancellable context for internal shutdown control ctx, cancel := context.WithCancel(ctx) defer cancel() - // Extract commonly used values from config for convenience - var ( - endpoint = config.Endpoint - id = config.ID - secret = config.Secret - mtu = config.MTU - logLevel = config.LogLevel - interfaceName = config.InterfaceName - pingInterval = config.PingIntervalDuration - pingTimeout = config.PingTimeoutDuration - doHolepunch = config.Holepunch - privateKey wgtypes.Key - connected bool - dev *device.Device - wgData WgData - holePunchData HolePunchData - uapiListener net.Listener - tdev tun.Device - ) - - stopHolepunch = make(chan struct{}) - stopPing = make(chan struct{}) - - loggerLevel := parseLogLevel(logLevel) - logger.GetLogger().SetLevel(parseLogLevel(logLevel)) + logger.GetLogger().SetLevel(parseLogLevel(config.LogLevel)) if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil { logger.Debug("Failed to check for updates: %v", err) } - // Log startup information - logger.Debug("Olm service starting...") - logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret) - - if doHolepunch { + if config.Holepunch { logger.Warn("Hole punching is enabled. This is EXPERIMENTAL and may not work in all environments.") } - var apiServer *api.API - if config.EnableAPI { - if config.HTTPAddr != "" { - apiServer = api.NewAPI(config.HTTPAddr) - } else if config.SocketPath != "" { - apiServer = api.NewAPISocket(config.SocketPath) - } - - apiServer.SetVersion(config.Version) - apiServer.SetOrgID(config.OrgID) - if err := apiServer.Start(); err != nil { - logger.Fatal("Failed to start HTTP server: %v", err) - } - - // Listen for shutdown requests from the API - go func() { - <-apiServer.GetShutdownChannel() - logger.Info("Shutdown requested via API") - // Cancel the context to trigger graceful shutdown - cancel() - }() + if config.HTTPAddr != "" { + apiServer = api.NewAPI(config.HTTPAddr) + } else if config.SocketPath != "" { + apiServer = api.NewAPISocket(config.SocketPath) } - // // Use a goroutine to handle connection requests - // go func() { - // for req := range apiServer.GetConnectionChannel() { - // logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) + apiServer.SetVersion(config.Version) + apiServer.SetOrgID(config.OrgID) - // // Set the connection parameters - // id = req.ID - // secret = req.Secret - // endpoint = req.Endpoint - // } - // }() - // } + if err := apiServer.Start(); err != nil { + logger.Fatal("Failed to start HTTP server: %v", err) + } - // Create a new olm - olm, err := websocket.NewClient( - "olm", - id, // CLI arg takes precedence - secret, // CLI arg takes precedence - endpoint, - pingInterval, - pingTimeout, + // Listen for shutdown requests from the API + go func() { + <-apiServer.GetShutdownChannel() + logger.Info("Shutdown requested via API") + // Cancel the context to trigger graceful shutdown + cancel() + }() + + var ( + id = config.ID + secret = config.Secret + endpoint = config.Endpoint ) - if err != nil { - logger.Fatal("Failed to create olm: %v", err) - } - // wait until we have a client id and secret and endpoint - waitCount := 0 - for id == "" || secret == "" || endpoint == "" { + // Main event loop that handles connect, disconnect, and reconnect + for { select { case <-ctx.Done(): logger.Info("Context cancelled while waiting for credentials") - return + goto shutdown + + case req := <-apiServer.GetConnectionChannel(): + logger.Info("Received connection request via HTTP: id=%s, endpoint=%s", req.ID, req.Endpoint) + + // Stop any existing tunnel before starting a new one + if olmClient != nil { + logger.Info("Stopping existing tunnel before starting new connection") + StopTunnel() + } + + // Set the connection parameters + id = req.ID + secret = req.Secret + endpoint = req.Endpoint + + // Start the tunnel process with the new credentials + if id != "" && secret != "" && endpoint != "" { + logger.Info("Starting tunnel with new credentials") + go TunnelProcess(ctx, config, id, secret, endpoint) + } + + case <-apiServer.GetDisconnectChannel(): + logger.Info("Received disconnect request via API") + StopTunnel() + // Clear credentials so we wait for new connect call + id = "" + secret = "" + endpoint = "" + default: - missing := []string{} - if id == "" { - missing = append(missing, "id") + // If we have credentials and no tunnel is running, start it + if id != "" && secret != "" && endpoint != "" && olmClient == nil { + logger.Info("Starting tunnel process with initial credentials") + go TunnelProcess(ctx, config, id, secret, endpoint) + } else if id == "" || secret == "" || endpoint == "" { + // If we don't have credentials, check if API is enabled + if !config.EnableAPI { + missing := []string{} + if id == "" { + missing = append(missing, "id") + } + if secret == "" { + missing = append(missing, "secret") + } + if endpoint == "" { + missing = append(missing, "endpoint") + } + // exit the application because there is no way to provide the missing parameters + logger.Fatal("Missing required parameters: %v and API is not enabled to provide them", missing) + goto shutdown + } } - if secret == "" { - missing = append(missing, "secret") - } - if endpoint == "" { - missing = append(missing, "endpoint") - } - waitCount++ - if waitCount%10 == 1 { // Log every 10 seconds instead of every second - logger.Debug("Waiting for missing parameters: %v (waiting %d seconds)", missing, waitCount) - } - time.Sleep(1 * time.Second) + + // Sleep briefly to prevent tight loop + time.Sleep(100 * time.Millisecond) } } +shutdown: + Stop() + apiServer.Stop() + logger.Info("Olm service shutting down") +} + +func TunnelProcess(ctx context.Context, config Config, id string, secret string, endpoint string) { + // Create a cancellable context for this tunnel process + tunnelCtx, cancel := context.WithCancel(ctx) + tunnelCancel = cancel + defer func() { + tunnelCancel = nil + }() + + // Recreate channels for this tunnel session + stopHolepunch = make(chan struct{}) + stopPing = make(chan struct{}) + + var ( + interfaceName = config.InterfaceName + loggerLevel = parseLogLevel(config.LogLevel) + ) + + // Create a new olm client using the provided credentials + olm, err := websocket.NewClient( + "olm", + id, // Use provided ID + secret, // Use provided secret + endpoint, // Use provided endpoint + config.PingIntervalDuration, + config.PingTimeoutDuration, + ) + if err != nil { + logger.Error("Failed to create olm: %v", err) + return + } + + // Store the client reference globally + olmClient = olm + privateKey, err = wgtypes.GeneratePrivateKey() if err != nil { - logger.Fatal("Failed to generate private key: %v", err) + logger.Error("Failed to generate private key: %v", err) + return } sourcePort, err := FindAvailableUDPPort(49152, 65535) if err != nil { - fmt.Printf("Error finding available port: %v\n", err) - os.Exit(1) + logger.Error("Error finding available port: %v", err) + return } olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) { @@ -289,12 +326,12 @@ func Run(ctx context.Context, config Config) { if err != nil { return nil, err } - return tun.CreateTUN(interfaceName, mtu) + return tun.CreateTUN(interfaceName, config.MTU) } if tunFdStr := os.Getenv(ENV_WG_TUN_FD); tunFdStr != "" { - return createTUNFromFD(tunFdStr, mtu) + return createTUNFromFD(tunFdStr, config.MTU) } - return tun.CreateTUN(interfaceName, mtu) + return tun.CreateTUN(interfaceName, config.MTU) }() if err != nil { @@ -347,27 +384,23 @@ func Run(ctx context.Context, config Config) { if err = ConfigureInterface(interfaceName, wgData); err != nil { logger.Error("Failed to configure interface: %v", err) } - if apiServer != nil { - apiServer.SetTunnelIP(wgData.TunnelIP) - } + apiServer.SetTunnelIP(wgData.TunnelIP) peerMonitor = peermonitor.NewPeerMonitor( func(siteID int, connected bool, rtt time.Duration) { - if apiServer != nil { - // Find the site config to get endpoint information - var endpoint string - var isRelay bool - for _, site := range wgData.Sites { - if site.SiteId == siteID { - endpoint = site.Endpoint - // TODO: We'll need to track relay status separately - // For now, assume not using relay unless we get relay data - isRelay = !doHolepunch - break - } + // Find the site config to get endpoint information + var endpoint string + var isRelay bool + for _, site := range wgData.Sites { + if site.SiteId == siteID { + endpoint = site.Endpoint + // TODO: We'll need to track relay status separately + // For now, assume not using relay unless we get relay data + isRelay = !config.Holepunch + break } - apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay) } + apiServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay) if connected { logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt) } else { @@ -377,14 +410,12 @@ func Run(ctx context.Context, config Config) { fixKey(privateKey.String()), olm, dev, - doHolepunch, + config.Holepunch, ) for i := range wgData.Sites { site := &wgData.Sites[i] // Use a pointer to modify the struct in the slice - if apiServer != nil { - apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false) - } + apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false) // Format the endpoint before configuring the peer. site.Endpoint = formatEndpoint(site.Endpoint) @@ -407,9 +438,7 @@ func Run(ctx context.Context, config Config) { peerMonitor.Start() - if apiServer != nil { - apiServer.SetRegistered(true) - } + apiServer.SetRegistered(true) connected = true @@ -637,9 +666,7 @@ func Run(ctx context.Context, config Config) { } // Update HTTP server to mark this peer as using relay - if apiServer != nil { - apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true) - } + apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true) peerMonitor.HandleFailover(relayData.SiteId, primaryRelay) }) @@ -670,9 +697,7 @@ func Run(ctx context.Context, config Config) { olm.OnConnect(func() error { logger.Info("Websocket Connected") - if apiServer != nil { - apiServer.SetConnectionStatus(true) - } + apiServer.SetConnectionStatus(true) if connected { logger.Debug("Already connected, skipping registration") @@ -681,11 +706,11 @@ func Run(ctx context.Context, config Config) { publicKey := privateKey.PublicKey() - logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch) + logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch) stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ "publicKey": publicKey.String(), - "relay": !doHolepunch, + "relay": !config.Holepunch, "olmVersion": config.Version, "orgId": config.OrgID, }, 1*time.Second) @@ -700,89 +725,50 @@ func Run(ctx context.Context, config Config) { olmToken = token }) - // Listen for org switch requests from the API - if apiServer != nil { - go func() { - for req := range apiServer.GetSwitchOrgChannel() { - logger.Info("Processing org switch request to orgId: %s", req.OrgID) - - // Update the config with the new orgId - config.OrgID = req.OrgID - - // Mark as not connected to trigger re-registration - connected = false - - // Stop registration if running - if stopRegister != nil { - stopRegister() - stopRegister = nil - } - - // Stop hole punching - select { - case <-stopHolepunch: - // Already closed - default: - close(stopHolepunch) - } - stopHolepunch = make(chan struct{}) - - // Stop peer monitor - if peerMonitor != nil { - peerMonitor.Stop() - peerMonitor = nil - } - - // Close the WireGuard device - if dev != nil { - logger.Info("Closing existing WireGuard device for org switch") - dev.Close() - dev = nil - } - - // Close UAPI listener - if uapiListener != nil { - uapiListener.Close() - uapiListener = nil - } - - // Close TUN device - if tdev != nil { - tdev.Close() - tdev = nil - } - - // Clear peer statuses in API - if apiServer != nil { - apiServer.SetRegistered(false) - apiServer.SetTunnelIP("") - apiServer.SetOrgID(config.OrgID) - } - - // Trigger re-registration with new orgId - logger.Info("Re-registering with new orgId: %s", config.OrgID) - publicKey := privateKey.PublicKey() - stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ - "publicKey": publicKey.String(), - "relay": !doHolepunch, - "olmVersion": config.Version, - "orgId": config.OrgID, - }, 1*time.Second) - } - }() - } - // Connect to the WebSocket server if err := olm.Connect(); err != nil { - logger.Fatal("Failed to connect to server: %v", err) + logger.Error("Failed to connect to server: %v", err) + return } defer olm.Close() - select { - case <-ctx.Done(): - logger.Info("Context cancelled") - } + // Listen for org switch requests from the API + go func() { + for req := range apiServer.GetSwitchOrgChannel() { + logger.Info("Processing org switch request to orgId: %s", req.OrgID) + // Update the config with the new orgId + config.OrgID = req.OrgID + + // Mark as not connected to trigger re-registration + connected = false + + Stop() + + // Clear peer statuses in API + apiServer.SetRegistered(false) + apiServer.SetTunnelIP("") + apiServer.SetOrgID(config.OrgID) + + stopHolepunch = make(chan struct{}) + // Trigger re-registration with new orgId + logger.Info("Re-registering with new orgId: %s", config.OrgID) + publicKey := privateKey.PublicKey() + stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ + "publicKey": publicKey.String(), + "relay": !config.Holepunch, + "olmVersion": config.Version, + "orgId": config.OrgID, + }, 1*time.Second) + } + }() + + // Wait for context cancellation + <-tunnelCtx.Done() + logger.Info("Tunnel process context cancelled, cleaning up") +} + +func Stop() { select { case <-stopHolepunch: // Channel already closed, do nothing @@ -790,11 +776,6 @@ func Run(ctx context.Context, config Config) { close(stopHolepunch) } - if stopRegister != nil { - stopRegister() - stopRegister = nil - } - select { case <-stopPing: // Channel already closed @@ -802,20 +783,60 @@ func Run(ctx context.Context, config Config) { close(stopPing) } + if stopRegister != nil { + stopRegister() + stopRegister = nil + } + if peerMonitor != nil { peerMonitor.Stop() + peerMonitor = nil } if uapiListener != nil { uapiListener.Close() + uapiListener = nil } if dev != nil { dev.Close() + dev = nil } - - if apiServer != nil { - apiServer.Stop() + // Close TUN device + if tdev != nil { + tdev.Close() + tdev = nil } logger.Info("Olm service stopped") } + +// StopTunnel stops just the tunnel process and websocket connection +// without shutting down the entire application +func StopTunnel() { + logger.Info("Stopping tunnel process") + + // Cancel the tunnel context if it exists + if tunnelCancel != nil { + tunnelCancel() + // Give it a moment to clean up + time.Sleep(200 * time.Millisecond) + } + + // Close the websocket connection + if olmClient != nil { + olmClient.Close() + olmClient = nil + } + + Stop() + + // Reset the connected state + connected = false + + // Update API server status + apiServer.SetConnectionStatus(false) + apiServer.SetRegistered(false) + apiServer.SetTunnelIP("") + + logger.Info("Tunnel process stopped") +}