diff --git a/api/api.go b/api/api.go index dd07751..0d2e4ef 100644 --- a/api/api.go +++ b/api/api.go @@ -18,6 +18,11 @@ type ConnectionRequest struct { Endpoint string `json:"endpoint"` } +// SwitchOrgRequest defines the structure for switching organizations +type SwitchOrgRequest struct { + OrgID string `json:"orgId"` +} + // PeerStatus represents the status of a peer connection type PeerStatus struct { SiteID int `json:"siteId"` @@ -35,6 +40,7 @@ type StatusResponse struct { Registered bool `json:"registered"` TunnelIP string `json:"tunnelIP,omitempty"` Version string `json:"version,omitempty"` + OrgID string `json:"orgId,omitempty"` PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"` } @@ -46,6 +52,7 @@ type API struct { server *http.Server connectionChan chan ConnectionRequest shutdownChan chan struct{} + switchOrgChan chan SwitchOrgRequest statusMu sync.RWMutex peerStatuses map[int]*PeerStatus connectedAt time.Time @@ -53,6 +60,7 @@ type API struct { isRegistered bool tunnelIP string version string + orgID string } // NewAPI creates a new HTTP server that listens on a TCP address @@ -61,6 +69,7 @@ func NewAPI(addr string) *API { addr: addr, connectionChan: make(chan ConnectionRequest, 1), shutdownChan: make(chan struct{}, 1), + switchOrgChan: make(chan SwitchOrgRequest, 1), peerStatuses: make(map[int]*PeerStatus), } @@ -73,6 +82,7 @@ func NewAPISocket(socketPath string) *API { socketPath: socketPath, connectionChan: make(chan ConnectionRequest, 1), shutdownChan: make(chan struct{}, 1), + switchOrgChan: make(chan SwitchOrgRequest, 1), peerStatuses: make(map[int]*PeerStatus), } @@ -85,6 +95,7 @@ func (s *API) Start() error { mux.HandleFunc("/connect", s.handleConnect) mux.HandleFunc("/status", s.handleStatus) mux.HandleFunc("/exit", s.handleExit) + mux.HandleFunc("/switch-org", s.handleSwitchOrg) s.server = &http.Server{ Handler: mux, @@ -143,6 +154,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} { return s.shutdownChan } +// GetSwitchOrgChannel returns the channel for receiving org switch requests +func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest { + return s.switchOrgChan +} + // UpdatePeerStatus updates the status of a peer including endpoint and relay info func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) { s.statusMu.Lock() @@ -198,6 +214,13 @@ func (s *API) SetVersion(version string) { s.version = version } +// SetOrgID sets the org ID +func (s *API) SetOrgID(orgID string) { + s.statusMu.Lock() + defer s.statusMu.Unlock() + s.orgID = orgID +} + // UpdatePeerRelayStatus updates only the relay status of a peer func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) { s.statusMu.Lock() @@ -261,6 +284,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) { Registered: s.isRegistered, TunnelIP: s.tunnelIP, Version: s.version, + OrgID: s.orgID, PeerStatuses: s.peerStatuses, } @@ -292,3 +316,44 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) { "status": "shutdown initiated", }) } + +// handleSwitchOrg handles the /switch-org endpoint +func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req SwitchOrgRequest + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest) + return + } + + // Validate required fields + if req.OrgID == "" { + http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest) + return + } + + logger.Info("Received org switch request to orgId: %s", req.OrgID) + + // Send the request to the main goroutine + select { + case s.switchOrgChan <- req: + // Signal sent successfully + default: + // Channel already has a signal, don't block + http.Error(w, "Org switch already in progress", http.StatusTooManyRequests) + return + } + + // Return a success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + json.NewEncoder(w).Encode(map[string]string{ + "status": "org switch initiated", + "orgId": req.OrgID, + }) +} diff --git a/olm/olm.go b/olm/olm.go index 78080c4..5e292d6 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -58,6 +58,58 @@ type Config struct { OrgID string } +// tunnelState holds all the active tunnel resources that need cleanup +type tunnelState struct { + dev *device.Device + tdev tun.Device + uapiListener net.Listener + peerMonitor *peermonitor.PeerMonitor + stopRegister func() + connected bool +} + +// teardownTunnel cleans up all tunnel resources +func teardownTunnel(state *tunnelState) { + if state == nil { + return + } + + logger.Info("Tearing down tunnel...") + + // Stop registration messages + if state.stopRegister != nil { + state.stopRegister() + state.stopRegister = nil + } + + // Stop peer monitor + if state.peerMonitor != nil { + state.peerMonitor.Stop() + state.peerMonitor = nil + } + + // Close UAPI listener + if state.uapiListener != nil { + state.uapiListener.Close() + state.uapiListener = nil + } + + // Close WireGuard device + if state.dev != nil { + state.dev.Close() + state.dev = nil + } + + // Close TUN device + if state.tdev != nil { + state.tdev.Close() + state.tdev = nil + } + + state.connected = false + logger.Info("Tunnel teardown complete") +} + func Run(ctx context.Context, config Config) { // Create a cancellable context for internal shutdown control ctx, cancel := context.WithCancel(ctx) @@ -75,14 +127,14 @@ func Run(ctx context.Context, config Config) { pingTimeout = config.PingTimeoutDuration doHolepunch = config.Holepunch privateKey wgtypes.Key - connected bool - dev *device.Device wgData WgData holePunchData HolePunchData - uapiListener net.Listener - tdev tun.Device + orgID = config.OrgID ) + // Tunnel state that can be torn down and recreated + tunnel := &tunnelState{} + stopHolepunch = make(chan struct{}) stopPing = make(chan struct{}) @@ -110,6 +162,7 @@ func Run(ctx context.Context, config Config) { } apiServer.SetVersion(config.Version) + apiServer.SetOrgID(orgID) if err := apiServer.Start(); err != nil { logger.Fatal("Failed to start HTTP server: %v", err) } @@ -249,14 +302,14 @@ func Run(ctx context.Context, config Config) { olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { logger.Debug("Received message: %v", msg.Data) - if connected { + if tunnel.connected { logger.Info("Already connected. Ignoring new connection request.") return } - if stopRegister != nil { - stopRegister() - stopRegister = nil + if tunnel.stopRegister != nil { + tunnel.stopRegister() + tunnel.stopRegister = nil } close(stopHolepunch) @@ -266,9 +319,9 @@ func Run(ctx context.Context, config Config) { time.Sleep(500 * time.Millisecond) // if there is an existing tunnel then close it - if dev != nil { + if tunnel.dev != nil { logger.Info("Got new message. Closing existing tunnel!") - dev.Close() + tunnel.dev.Close() } jsonData, err := json.Marshal(msg.Data) @@ -282,7 +335,7 @@ func Run(ctx context.Context, config Config) { return } - tdev, err = func() (tun.Device, error) { + tunnel.tdev, err = func() (tun.Device, error) { if runtime.GOOS == "darwin" { interfaceName, err := findUnusedUTUN() if err != nil { @@ -301,7 +354,7 @@ func Run(ctx context.Context, config Config) { return } - if realInterfaceName, err2 := tdev.Name(); err2 == nil { + if realInterfaceName, err2 := tunnel.tdev.Name(); err2 == nil { interfaceName = realInterfaceName } @@ -321,9 +374,9 @@ func Run(ctx context.Context, config Config) { return } - dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: ")) + tunnel.dev = device.NewDevice(tunnel.tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: ")) - uapiListener, err = uapiListen(interfaceName, fileUAPI) + tunnel.uapiListener, err = uapiListen(interfaceName, fileUAPI) if err != nil { logger.Error("Failed to listen on uapi socket: %v", err) os.Exit(1) @@ -331,16 +384,16 @@ func Run(ctx context.Context, config Config) { go func() { for { - conn, err := uapiListener.Accept() + conn, err := tunnel.uapiListener.Accept() if err != nil { return } - go dev.IpcHandle(conn) + go tunnel.dev.IpcHandle(conn) } }() logger.Info("UAPI listener started") - if err = dev.Up(); err != nil { + if err = tunnel.dev.Up(); err != nil { logger.Error("Failed to bring up WireGuard device: %v", err) } if err = ConfigureInterface(interfaceName, wgData); err != nil { @@ -350,7 +403,7 @@ func Run(ctx context.Context, config Config) { apiServer.SetTunnelIP(wgData.TunnelIP) } - peerMonitor = peermonitor.NewPeerMonitor( + tunnel.peerMonitor = peermonitor.NewPeerMonitor( func(siteID int, connected bool, rtt time.Duration) { if apiServer != nil { // Find the site config to get endpoint information @@ -375,7 +428,7 @@ func Run(ctx context.Context, config Config) { }, fixKey(privateKey.String()), olm, - dev, + tunnel.dev, doHolepunch, ) @@ -388,7 +441,7 @@ func Run(ctx context.Context, config Config) { // Format the endpoint before configuring the peer. site.Endpoint = formatEndpoint(site.Endpoint) - if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil { + if err := ConfigurePeer(tunnel.dev, *site, privateKey, endpoint); err != nil { logger.Error("Failed to configure peer: %v", err) return } @@ -404,13 +457,13 @@ func Run(ctx context.Context, config Config) { logger.Info("Configured peer %s", site.PublicKey) } - peerMonitor.Start() + tunnel.peerMonitor.Start() if apiServer != nil { apiServer.SetRegistered(true) } - connected = true + tunnel.connected = true logger.Info("WireGuard device created.") }) @@ -441,7 +494,7 @@ func Run(ctx context.Context, config Config) { } // Update the peer in WireGuard - if dev != nil { + if tunnel.dev != nil { // Find the existing peer to get old data var oldRemoteSubnets string var oldPublicKey string @@ -456,7 +509,7 @@ func Run(ctx context.Context, config Config) { // If the public key has changed, remove the old peer first if oldPublicKey != "" && oldPublicKey != updateData.PublicKey { logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey) - if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil { + if err := RemovePeer(tunnel.dev, updateData.SiteId, oldPublicKey); err != nil { logger.Error("Failed to remove old peer: %v", err) return } @@ -465,7 +518,7 @@ func Run(ctx context.Context, config Config) { // Format the endpoint before updating the peer. siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) - if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { + if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to update peer: %v", err) return } @@ -524,11 +577,11 @@ func Run(ctx context.Context, config Config) { } // Add the peer to WireGuard - if dev != nil { + if tunnel.dev != nil { // Format the endpoint before adding the new peer. siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) - if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { + if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to add peer: %v", err) return } @@ -585,8 +638,8 @@ func Run(ctx context.Context, config Config) { } // Remove the peer from WireGuard - if dev != nil { - if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil { + if tunnel.dev != nil { + if err := RemovePeer(tunnel.dev, removeData.SiteId, peerToRemove.PublicKey); err != nil { logger.Error("Failed to remove peer: %v", err) // Send error response if needed return @@ -640,7 +693,7 @@ func Run(ctx context.Context, config Config) { apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true) } - peerMonitor.HandleFailover(relayData.SiteId, primaryRelay) + tunnel.peerMonitor.HandleFailover(relayData.SiteId, primaryRelay) }) olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) { @@ -673,7 +726,7 @@ func Run(ctx context.Context, config Config) { apiServer.SetConnectionStatus(true) } - if connected { + if tunnel.connected { logger.Debug("Already connected, skipping registration") return nil } @@ -682,11 +735,11 @@ func Run(ctx context.Context, config Config) { logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch) - stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ + tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ "publicKey": publicKey.String(), "relay": !doHolepunch, "olmVersion": config.Version, - "orgId": config.OrgID, + "orgId": orgID, }, 1*time.Second) go keepSendingPing(olm) @@ -705,6 +758,49 @@ func Run(ctx context.Context, config Config) { } defer olm.Close() + // Listen for org switch requests from the API (after olm is created) + if apiServer != nil { + go func() { + for req := range apiServer.GetSwitchOrgChannel() { + logger.Info("Org switch requested via API to orgId: %s", req.OrgID) + + // Update the orgId + orgID = req.OrgID + + // Teardown existing tunnel + teardownTunnel(tunnel) + + // Reset tunnel state + tunnel = &tunnelState{} + + // Stop holepunch + select { + case <-stopHolepunch: + // Channel already closed + default: + close(stopHolepunch) + } + stopHolepunch = make(chan struct{}) + + // Clear API server state + apiServer.SetRegistered(false) + apiServer.SetTunnelIP("") + apiServer.SetOrgID(orgID) + + // Send new registration message with updated orgId + publicKey := privateKey.PublicKey() + logger.Info("Sending registration message with new orgId: %s", orgID) + + tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ + "publicKey": publicKey.String(), + "relay": !doHolepunch, + "olmVersion": config.Version, + "orgId": orgID, + }, 1*time.Second) + } + }() + } + select { case <-ctx.Done(): logger.Info("Context cancelled") @@ -717,9 +813,9 @@ func Run(ctx context.Context, config Config) { close(stopHolepunch) } - if stopRegister != nil { - stopRegister() - stopRegister = nil + if tunnel.stopRegister != nil { + tunnel.stopRegister() + tunnel.stopRegister = nil } select { @@ -729,16 +825,8 @@ func Run(ctx context.Context, config Config) { close(stopPing) } - if peerMonitor != nil { - peerMonitor.Stop() - } - - if uapiListener != nil { - uapiListener.Close() - } - if dev != nil { - dev.Close() - } + // Use teardownTunnel to clean up all tunnel resources + teardownTunnel(tunnel) if apiServer != nil { apiServer.Stop()