From 2718d1582561276581b4b1a9c9a8a2d229e0e161 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 24 Nov 2025 17:36:44 -0500 Subject: [PATCH] Add new api calls and onterminate Former-commit-id: 96143e4b38589fc1cef746c32bc3a127b45e7435 --- api/api.go | 11 +++++ olm/olm.go | 126 +++++++++++++++++++-------------------------------- olm/types.go | 49 ++++++++++++++++++++ 3 files changed, 107 insertions(+), 79 deletions(-) diff --git a/api/api.go b/api/api.go index 2316373..7fe8898 100644 --- a/api/api.go +++ b/api/api.go @@ -415,3 +415,14 @@ func (s *API) handleDisconnect(w http.ResponseWriter, r *http.Request) { "status": "disconnect initiated", }) } + +func (s *API) GetStatus() StatusResponse { + return StatusResponse{ + Connected: s.isConnected, + Registered: s.isRegistered, + Version: s.version, + OrgID: s.orgID, + PeerStatuses: s.peerStatuses, + NetworkSettings: network.GetSettings(), + } +} diff --git a/olm/olm.go b/olm/olm.go index 65ec9c1..1544c86 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -25,52 +25,6 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -type GlobalConfig struct { - // Logging - LogLevel string - - // HTTP server - EnableAPI bool - HTTPAddr string - SocketPath string - Version string - - // Callbacks - OnRegistered func() - OnConnected func() - - // Source tracking (not in JSON) - sources map[string]string -} - -type TunnelConfig struct { - // Connection settings - Endpoint string - ID string - Secret string - UserToken string - - // Network settings - MTU int - DNS string - UpstreamDNS []string - InterfaceName string - - // Advanced - Holepunch bool - TlsClientCert string - - // Parsed values (not in JSON) - PingIntervalDuration time.Duration - PingTimeoutDuration time.Duration - - OrgID string - // DoNotCreateNewClient bool - - FileDescriptorTun uint32 - FileDescriptorUAPI uint32 -} - var ( privateKey wgtypes.Key connected bool @@ -184,41 +138,13 @@ func Init(ctx context.Context, config GlobalConfig) { }, // onSwitchOrg func(req api.SwitchOrgRequest) error { - logger.Info("Processing org switch request to orgId: %s", req.OrgID) - - // Ensure we have an active olmClient - if olmClient == nil { - return fmt.Errorf("no active connection to switch organizations") - } - - // Update the orgID in the API server - apiServer.SetOrgID(req.OrgID) - - // Mark as not connected to trigger re-registration - connected = false - - Close() - - // Clear peer statuses in API - apiServer.SetRegistered(false) - - // Trigger re-registration with new orgId - logger.Info("Re-registering with new orgId: %s", req.OrgID) - publicKey := privateKey.PublicKey() - stopRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{ - "publicKey": publicKey.String(), - "relay": true, // Default to relay mode for org switch - "olmVersion": globalConfig.Version, - "orgId": req.OrgID, - }, 1*time.Second) - - return nil + logger.Info("Received switch organization request via HTTP: orgID=%s", req.OrgID) + return SwitchOrg(req.OrgID) }, // onDisconnect func() error { logger.Info("Processing disconnect request via API") - StopTunnel() - return nil + return StopTunnel() }, // onExit func() error { @@ -1020,7 +946,11 @@ func StartTunnel(config TunnelConfig) { olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { logger.Info("Received terminate message") - olm.Close() + Close() + + if globalConfig.OnTerminated != nil { + go globalConfig.OnTerminated() + } }) olm.OnConnect(func() error { @@ -1155,7 +1085,7 @@ func Close() { // StopTunnel stops just the tunnel process and websocket connection // without shutting down the entire application -func StopTunnel() { +func StopTunnel() error { logger.Info("Stopping tunnel process") // Cancel the tunnel context if it exists @@ -1189,6 +1119,8 @@ func StopTunnel() { network.ClearNetworkSettings() logger.Info("Tunnel process stopped") + + return nil } func StopApi() error { @@ -1210,3 +1142,39 @@ func StartApi() error { } return nil } + +func GetStatus() api.StatusResponse { + return apiServer.GetStatus() +} + +func SwitchOrg(orgID string) error { + logger.Info("Processing org switch request to orgId: %s", orgID) + + // Ensure we have an active olmClient + if olmClient == nil { + return fmt.Errorf("no active connection to switch organizations") + } + + // Update the orgID in the API server + apiServer.SetOrgID(orgID) + + // Mark as not connected to trigger re-registration + connected = false + + Close() + + // Clear peer statuses in API + apiServer.SetRegistered(false) + + // Trigger re-registration with new orgId + logger.Info("Re-registering with new orgId: %s", orgID) + publicKey := privateKey.PublicKey() + stopRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{ + "publicKey": publicKey.String(), + "relay": true, // Default to relay mode for org switch + "olmVersion": globalConfig.Version, + "orgId": orgID, + }, 1*time.Second) + + return nil +} diff --git a/olm/types.go b/olm/types.go index 96f63b9..92081ad 100644 --- a/olm/types.go +++ b/olm/types.go @@ -1,5 +1,7 @@ package olm +import "time" + type WgData struct { Sites []SiteConfig `json:"sites"` TunnelIP string `json:"tunnelIP"` @@ -75,3 +77,50 @@ type UpdateRemoteSubnetsData struct { OldRemoteSubnets []string `json:"oldRemoteSubnets"` // old list of remote subnets NewRemoteSubnets []string `json:"newRemoteSubnets"` // new list of remote subnets } + +type GlobalConfig struct { + // Logging + LogLevel string + + // HTTP server + EnableAPI bool + HTTPAddr string + SocketPath string + Version string + + // Callbacks + OnRegistered func() + OnConnected func() + OnTerminated func() + + // Source tracking (not in JSON) + sources map[string]string +} + +type TunnelConfig struct { + // Connection settings + Endpoint string + ID string + Secret string + UserToken string + + // Network settings + MTU int + DNS string + UpstreamDNS []string + InterfaceName string + + // Advanced + Holepunch bool + TlsClientCert string + + // Parsed values (not in JSON) + PingIntervalDuration time.Duration + PingTimeoutDuration time.Duration + + OrgID string + // DoNotCreateNewClient bool + + FileDescriptorTun uint32 + FileDescriptorUAPI uint32 +}