From 690b133c7b442626f11078bdbab59cecc0cd0c76 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 3 Nov 2025 20:33:06 -0800 Subject: [PATCH] Update switching orgs --- api/api.go | 54 ++++++ diff | 523 +++++++++++++++++++++++++++++++++++++++++++++++++++++ olm/olm.go | 71 ++++++++ 3 files changed, 648 insertions(+) create mode 100644 diff diff --git a/api/api.go b/api/api.go index dd07751..adc613e 100644 --- a/api/api.go +++ b/api/api.go @@ -18,6 +18,11 @@ type ConnectionRequest struct { Endpoint string `json:"endpoint"` } +// SwitchOrgRequest defines the structure for switching organizations +type SwitchOrgRequest struct { + OrgID string `json:"orgId"` +} + // PeerStatus represents the status of a peer connection type PeerStatus struct { SiteID int `json:"siteId"` @@ -45,6 +50,7 @@ type API struct { listener net.Listener server *http.Server connectionChan chan ConnectionRequest + switchOrgChan chan SwitchOrgRequest shutdownChan chan struct{} statusMu sync.RWMutex peerStatuses map[int]*PeerStatus @@ -60,6 +66,7 @@ func NewAPI(addr string) *API { s := &API{ addr: addr, connectionChan: make(chan ConnectionRequest, 1), + switchOrgChan: make(chan SwitchOrgRequest, 1), shutdownChan: make(chan struct{}, 1), peerStatuses: make(map[int]*PeerStatus), } @@ -72,6 +79,7 @@ func NewAPISocket(socketPath string) *API { s := &API{ socketPath: socketPath, connectionChan: make(chan ConnectionRequest, 1), + switchOrgChan: make(chan SwitchOrgRequest, 1), shutdownChan: make(chan struct{}, 1), peerStatuses: make(map[int]*PeerStatus), } @@ -84,6 +92,7 @@ func (s *API) Start() error { mux := http.NewServeMux() mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/status", s.handleStatus) + mux.HandleFunc("/switch-org", s.handleSwitchOrg) mux.HandleFunc("/exit", s.handleExit) s.server = &http.Server{ @@ -138,6 +147,11 @@ func (s *API) GetConnectionChannel() <-chan ConnectionRequest { return s.connectionChan } +// GetSwitchOrgChannel returns the channel for receiving org switch requests +func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest { + return s.switchOrgChan +} + // GetShutdownChannel returns the channel for receiving shutdown requests func (s *API) GetShutdownChannel() <-chan struct{} { return s.shutdownChan @@ -292,3 +306,43 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) { "status": "shutdown initiated", }) } + +// handleSwitchOrg handles the /switch-org endpoint +func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req SwitchOrgRequest + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest) + return + } + + // Validate required fields + if req.OrgID == "" { + http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest) + return + } + + logger.Info("Received org switch request to orgId: %s", req.OrgID) + + // Send the request to the main goroutine + select { + case s.switchOrgChan <- req: + // Signal sent successfully + default: + // Channel already has a pending request + http.Error(w, "Org switch already in progress", http.StatusConflict) + return + } + + // Return a success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + json.NewEncoder(w).Encode(map[string]string{ + "status": "org switch request accepted", + }) +} diff --git a/diff b/diff new file mode 100644 index 0000000..da7e62c --- /dev/null +++ b/diff @@ -0,0 +1,523 @@ +diff --git a/api/api.go b/api/api.go +index dd07751..0d2e4ef 100644 +--- a/api/api.go ++++ b/api/api.go +@@ -18,6 +18,11 @@ type ConnectionRequest struct { + Endpoint string `json:"endpoint"` + } + ++// SwitchOrgRequest defines the structure for switching organizations ++type SwitchOrgRequest struct { ++ OrgID string `json:"orgId"` ++} ++ + // PeerStatus represents the status of a peer connection + type PeerStatus struct { + SiteID int `json:"siteId"` +@@ -35,6 +40,7 @@ type StatusResponse struct { + Registered bool `json:"registered"` + TunnelIP string `json:"tunnelIP,omitempty"` + Version string `json:"version,omitempty"` ++ OrgID string `json:"orgId,omitempty"` + PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"` + } + +@@ -46,6 +52,7 @@ type API struct { + server *http.Server + connectionChan chan ConnectionRequest + shutdownChan chan struct{} ++ switchOrgChan chan SwitchOrgRequest + statusMu sync.RWMutex + peerStatuses map[int]*PeerStatus + connectedAt time.Time +@@ -53,6 +60,7 @@ type API struct { + isRegistered bool + tunnelIP string + version string ++ orgID string + } + + // NewAPI creates a new HTTP server that listens on a TCP address +@@ -61,6 +69,7 @@ func NewAPI(addr string) *API { + addr: addr, + connectionChan: make(chan ConnectionRequest, 1), + shutdownChan: make(chan struct{}, 1), ++ switchOrgChan: make(chan SwitchOrgRequest, 1), + peerStatuses: make(map[int]*PeerStatus), + } + +@@ -73,6 +82,7 @@ func NewAPISocket(socketPath string) *API { + socketPath: socketPath, + connectionChan: make(chan ConnectionRequest, 1), + shutdownChan: make(chan struct{}, 1), ++ switchOrgChan: make(chan SwitchOrgRequest, 1), + peerStatuses: make(map[int]*PeerStatus), + } + +@@ -85,6 +95,7 @@ func (s *API) Start() error { + mux.HandleFunc("/connect", s.handleConnect) + mux.HandleFunc("/status", s.handleStatus) + mux.HandleFunc("/exit", s.handleExit) ++ mux.HandleFunc("/switch-org", s.handleSwitchOrg) + + s.server = &http.Server{ + Handler: mux, +@@ -143,6 +154,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} { + return s.shutdownChan + } + ++// GetSwitchOrgChannel returns the channel for receiving org switch requests ++func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest { ++ return s.switchOrgChan ++} ++ + // UpdatePeerStatus updates the status of a peer including endpoint and relay info + func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) { + s.statusMu.Lock() +@@ -198,6 +214,13 @@ func (s *API) SetVersion(version string) { + s.version = version + } + ++// SetOrgID sets the org ID ++func (s *API) SetOrgID(orgID string) { ++ s.statusMu.Lock() ++ defer s.statusMu.Unlock() ++ s.orgID = orgID ++} ++ + // UpdatePeerRelayStatus updates only the relay status of a peer + func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) { + s.statusMu.Lock() +@@ -261,6 +284,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { + Registered: s.isRegistered, + TunnelIP: s.tunnelIP, + Version: s.version, ++ OrgID: s.orgID, + PeerStatuses: s.peerStatuses, + } + +@@ -292,3 +316,44 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) { + "status": "shutdown initiated", + }) + } ++ ++// handleSwitchOrg handles the /switch-org endpoint ++func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) { ++ if r.Method != http.MethodPost { ++ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) ++ return ++ } ++ ++ var req SwitchOrgRequest ++ decoder := json.NewDecoder(r.Body) ++ if err := decoder.Decode(&req); err != nil { ++ http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest) ++ return ++ } ++ ++ // Validate required fields ++ if req.OrgID == "" { ++ http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest) ++ return ++ } ++ ++ logger.Info("Received org switch request to orgId: %s", req.OrgID) ++ ++ // Send the request to the main goroutine ++ select { ++ case s.switchOrgChan <- req: ++ // Signal sent successfully ++ default: ++ // Channel already has a signal, don't block ++ http.Error(w, "Org switch already in progress", http.StatusTooManyRequests) ++ return ++ } ++ ++ // Return a success response ++ w.Header().Set("Content-Type", "application/json") ++ w.WriteHeader(http.StatusAccepted) ++ json.NewEncoder(w).Encode(map[string]string{ ++ "status": "org switch initiated", ++ "orgId": req.OrgID, ++ }) ++} +diff --git a/olm/olm.go b/olm/olm.go +index 78080c4..5e292d6 100644 +--- a/olm/olm.go ++++ b/olm/olm.go +@@ -58,6 +58,58 @@ type Config struct { + OrgID string + } + ++// tunnelState holds all the active tunnel resources that need cleanup ++type tunnelState struct { ++ dev *device.Device ++ tdev tun.Device ++ uapiListener net.Listener ++ peerMonitor *peermonitor.PeerMonitor ++ stopRegister func() ++ connected bool ++} ++ ++// teardownTunnel cleans up all tunnel resources ++func teardownTunnel(state *tunnelState) { ++ if state == nil { ++ return ++ } ++ ++ logger.Info("Tearing down tunnel...") ++ ++ // Stop registration messages ++ if state.stopRegister != nil { ++ state.stopRegister() ++ state.stopRegister = nil ++ } ++ ++ // Stop peer monitor ++ if state.peerMonitor != nil { ++ state.peerMonitor.Stop() ++ state.peerMonitor = nil ++ } ++ ++ // Close UAPI listener ++ if state.uapiListener != nil { ++ state.uapiListener.Close() ++ state.uapiListener = nil ++ } ++ ++ // Close WireGuard device ++ if state.dev != nil { ++ state.dev.Close() ++ state.dev = nil ++ } ++ ++ // Close TUN device ++ if state.tdev != nil { ++ state.tdev.Close() ++ state.tdev = nil ++ } ++ ++ state.connected = false ++ logger.Info("Tunnel teardown complete") ++} ++ + func Run(ctx context.Context, config Config) { + // Create a cancellable context for internal shutdown control + ctx, cancel := context.WithCancel(ctx) +@@ -75,14 +127,14 @@ func Run(ctx context.Context, config Config) { + pingTimeout = config.PingTimeoutDuration + doHolepunch = config.Holepunch + privateKey wgtypes.Key +- connected bool +- dev *device.Device + wgData WgData + holePunchData HolePunchData +- uapiListener net.Listener +- tdev tun.Device ++ orgID = config.OrgID + ) + ++ // Tunnel state that can be torn down and recreated ++ tunnel := &tunnelState{} ++ + stopHolepunch = make(chan struct{}) + stopPing = make(chan struct{}) + +@@ -110,6 +162,7 @@ func Run(ctx context.Context, config Config) { + } + + apiServer.SetVersion(config.Version) ++ apiServer.SetOrgID(orgID) + if err := apiServer.Start(); err != nil { + logger.Fatal("Failed to start HTTP server: %v", err) + } +@@ -249,14 +302,14 @@ func Run(ctx context.Context, config Config) { + olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { + logger.Debug("Received message: %v", msg.Data) + +- if connected { ++ if tunnel.connected { + logger.Info("Already connected. Ignoring new connection request.") + return + } + +- if stopRegister != nil { +- stopRegister() +- stopRegister = nil ++ if tunnel.stopRegister != nil { ++ tunnel.stopRegister() ++ tunnel.stopRegister = nil + } + + close(stopHolepunch) +@@ -266,9 +319,9 @@ func Run(ctx context.Context, config Config) { + time.Sleep(500 * time.Millisecond) + + // if there is an existing tunnel then close it +- if dev != nil { ++ if tunnel.dev != nil { + logger.Info("Got new message. Closing existing tunnel!") +- dev.Close() ++ tunnel.dev.Close() + } + + jsonData, err := json.Marshal(msg.Data) +@@ -282,7 +335,7 @@ func Run(ctx context.Context, config Config) { + return + } + +- tdev, err = func() (tun.Device, error) { ++ tunnel.tdev, err = func() (tun.Device, error) { + if runtime.GOOS == "darwin" { + interfaceName, err := findUnusedUTUN() + if err != nil { +@@ -301,7 +354,7 @@ func Run(ctx context.Context, config Config) { + return + } + +- if realInterfaceName, err2 := tdev.Name(); err2 == nil { ++ if realInterfaceName, err2 := tunnel.tdev.Name(); err2 == nil { + interfaceName = realInterfaceName + } + +@@ -321,9 +374,9 @@ func Run(ctx context.Context, config Config) { + return + } + +- dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: ")) ++ tunnel.dev = device.NewDevice(tunnel.tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: ")) + +- uapiListener, err = uapiListen(interfaceName, fileUAPI) ++ tunnel.uapiListener, err = uapiListen(interfaceName, fileUAPI) + if err != nil { + logger.Error("Failed to listen on uapi socket: %v", err) + os.Exit(1) +@@ -331,16 +384,16 @@ func Run(ctx context.Context, config Config) { + + go func() { + for { +- conn, err := uapiListener.Accept() ++ conn, err := tunnel.uapiListener.Accept() + if err != nil { + return + } +- go dev.IpcHandle(conn) ++ go tunnel.dev.IpcHandle(conn) + } + }() + logger.Info("UAPI listener started") + +- if err = dev.Up(); err != nil { ++ if err = tunnel.dev.Up(); err != nil { + logger.Error("Failed to bring up WireGuard device: %v", err) + } + if err = ConfigureInterface(interfaceName, wgData); err != nil { +@@ -350,7 +403,7 @@ func Run(ctx context.Context, config Config) { + apiServer.SetTunnelIP(wgData.TunnelIP) + } + +- peerMonitor = peermonitor.NewPeerMonitor( ++ tunnel.peerMonitor = peermonitor.NewPeerMonitor( + func(siteID int, connected bool, rtt time.Duration) { + if apiServer != nil { + // Find the site config to get endpoint information +@@ -375,7 +428,7 @@ func Run(ctx context.Context, config Config) { + }, + fixKey(privateKey.String()), + olm, +- dev, ++ tunnel.dev, + doHolepunch, + ) + +@@ -388,7 +441,7 @@ func Run(ctx context.Context, config Config) { + // Format the endpoint before configuring the peer. + site.Endpoint = formatEndpoint(site.Endpoint) + +- if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil { ++ if err := ConfigurePeer(tunnel.dev, *site, privateKey, endpoint); err != nil { + logger.Error("Failed to configure peer: %v", err) + return + } +@@ -404,13 +457,13 @@ func Run(ctx context.Context, config Config) { + logger.Info("Configured peer %s", site.PublicKey) + } + +- peerMonitor.Start() ++ tunnel.peerMonitor.Start() + + if apiServer != nil { + apiServer.SetRegistered(true) + } + +- connected = true ++ tunnel.connected = true + + logger.Info("WireGuard device created.") + }) +@@ -441,7 +494,7 @@ func Run(ctx context.Context, config Config) { + } + + // Update the peer in WireGuard +- if dev != nil { ++ if tunnel.dev != nil { + // Find the existing peer to get old data + var oldRemoteSubnets string + var oldPublicKey string +@@ -456,7 +509,7 @@ func Run(ctx context.Context, config Config) { + // If the public key has changed, remove the old peer first + if oldPublicKey != "" && oldPublicKey != updateData.PublicKey { + logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey) +- if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil { ++ if err := RemovePeer(tunnel.dev, updateData.SiteId, oldPublicKey); err != nil { + logger.Error("Failed to remove old peer: %v", err) + return + } +@@ -465,7 +518,7 @@ func Run(ctx context.Context, config Config) { + // Format the endpoint before updating the peer. + siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) + +- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { ++ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil { + logger.Error("Failed to update peer: %v", err) + return + } +@@ -524,11 +577,11 @@ func Run(ctx context.Context, config Config) { + } + + // Add the peer to WireGuard +- if dev != nil { ++ if tunnel.dev != nil { + // Format the endpoint before adding the new peer. + siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) + +- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { ++ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil { + logger.Error("Failed to add peer: %v", err) + return + } +@@ -585,8 +638,8 @@ func Run(ctx context.Context, config Config) { + } + + // Remove the peer from WireGuard +- if dev != nil { +- if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil { ++ if tunnel.dev != nil { ++ if err := RemovePeer(tunnel.dev, removeData.SiteId, peerToRemove.PublicKey); err != nil { + logger.Error("Failed to remove peer: %v", err) + // Send error response if needed + return +@@ -640,7 +693,7 @@ func Run(ctx context.Context, config Config) { + apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true) + } + +- peerMonitor.HandleFailover(relayData.SiteId, primaryRelay) ++ tunnel.peerMonitor.HandleFailover(relayData.SiteId, primaryRelay) + }) + + olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) { +@@ -673,7 +726,7 @@ func Run(ctx context.Context, config Config) { + apiServer.SetConnectionStatus(true) + } + +- if connected { ++ if tunnel.connected { + logger.Debug("Already connected, skipping registration") + return nil + } +@@ -682,11 +735,11 @@ func Run(ctx context.Context, config Config) { + + logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch) + +- stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ ++ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ + "publicKey": publicKey.String(), + "relay": !doHolepunch, + "olmVersion": config.Version, +- "orgId": config.OrgID, ++ "orgId": orgID, + }, 1*time.Second) + + go keepSendingPing(olm) +@@ -705,6 +758,49 @@ func Run(ctx context.Context, config Config) { + } + defer olm.Close() + ++ // Listen for org switch requests from the API (after olm is created) ++ if apiServer != nil { ++ go func() { ++ for req := range apiServer.GetSwitchOrgChannel() { ++ logger.Info("Org switch requested via API to orgId: %s", req.OrgID) ++ ++ // Update the orgId ++ orgID = req.OrgID ++ ++ // Teardown existing tunnel ++ teardownTunnel(tunnel) ++ ++ // Reset tunnel state ++ tunnel = &tunnelState{} ++ ++ // Stop holepunch ++ select { ++ case <-stopHolepunch: ++ // Channel already closed ++ default: ++ close(stopHolepunch) ++ } ++ stopHolepunch = make(chan struct{}) ++ ++ // Clear API server state ++ apiServer.SetRegistered(false) ++ apiServer.SetTunnelIP("") ++ apiServer.SetOrgID(orgID) ++ ++ // Send new registration message with updated orgId ++ publicKey := privateKey.PublicKey() ++ logger.Info("Sending registration message with new orgId: %s", orgID) ++ ++ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ ++ "publicKey": publicKey.String(), ++ "relay": !doHolepunch, ++ "olmVersion": config.Version, ++ "orgId": orgID, ++ }, 1*time.Second) ++ } ++ }() ++ } ++ + select { + case <-ctx.Done(): + logger.Info("Context cancelled") +@@ -717,9 +813,9 @@ func Run(ctx context.Context, config Config) { + close(stopHolepunch) + } + +- if stopRegister != nil { +- stopRegister() +- stopRegister = nil ++ if tunnel.stopRegister != nil { ++ tunnel.stopRegister() ++ tunnel.stopRegister = nil + } + + select { +@@ -729,16 +825,8 @@ func Run(ctx context.Context, config Config) { + close(stopPing) + } + +- if peerMonitor != nil { +- peerMonitor.Stop() +- } +- +- if uapiListener != nil { +- uapiListener.Close() +- } +- if dev != nil { +- dev.Close() +- } ++ // Use teardownTunnel to clean up all tunnel resources ++ teardownTunnel(tunnel) + + if apiServer != nil { + apiServer.Stop() diff --git a/olm/olm.go b/olm/olm.go index 78080c4..746f350 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -699,6 +699,77 @@ func Run(ctx context.Context, config Config) { olmToken = token }) + // Listen for org switch requests from the API + if apiServer != nil { + go func() { + for req := range apiServer.GetSwitchOrgChannel() { + logger.Info("Processing org switch request to orgId: %s", req.OrgID) + + // Update the config with the new orgId + config.OrgID = req.OrgID + + // Mark as not connected to trigger re-registration + connected = false + + // Stop registration if running + if stopRegister != nil { + stopRegister() + stopRegister = nil + } + + // Stop hole punching + select { + case <-stopHolepunch: + // Already closed + default: + close(stopHolepunch) + } + stopHolepunch = make(chan struct{}) + + // Stop peer monitor + if peerMonitor != nil { + peerMonitor.Stop() + peerMonitor = nil + } + + // Close the WireGuard device + if dev != nil { + logger.Info("Closing existing WireGuard device for org switch") + dev.Close() + dev = nil + } + + // Close UAPI listener + if uapiListener != nil { + uapiListener.Close() + uapiListener = nil + } + + // Close TUN device + if tdev != nil { + tdev.Close() + tdev = nil + } + + // Clear peer statuses in API + if apiServer != nil { + apiServer.SetRegistered(false) + apiServer.SetTunnelIP("") + } + + // Trigger re-registration with new orgId + logger.Info("Re-registering with new orgId: %s", config.OrgID) + publicKey := privateKey.PublicKey() + stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ + "publicKey": publicKey.String(), + "relay": !doHolepunch, + "olmVersion": config.Version, + "orgId": config.OrgID, + }, 1*time.Second) + } + }() + } + // Connect to the WebSocket server if err := olm.Connect(); err != nil { logger.Fatal("Failed to connect to server: %v", err)