Update switching orgs

Former-commit-id: 690b133c7b
This commit is contained in:
Owen
2025-11-03 20:33:06 -08:00
parent 43b3822090
commit 38eb56381f
3 changed files with 648 additions and 0 deletions

View File

@@ -18,6 +18,11 @@ type ConnectionRequest struct {
Endpoint string `json:"endpoint"`
}
// SwitchOrgRequest defines the structure for switching organizations
type SwitchOrgRequest struct {
OrgID string `json:"orgId"`
}
// PeerStatus represents the status of a peer connection
type PeerStatus struct {
SiteID int `json:"siteId"`
@@ -45,6 +50,7 @@ type API struct {
listener net.Listener
server *http.Server
connectionChan chan ConnectionRequest
switchOrgChan chan SwitchOrgRequest
shutdownChan chan struct{}
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
@@ -60,6 +66,7 @@ func NewAPI(addr string) *API {
s := &API{
addr: addr,
connectionChan: make(chan ConnectionRequest, 1),
switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus),
}
@@ -72,6 +79,7 @@ func NewAPISocket(socketPath string) *API {
s := &API{
socketPath: socketPath,
connectionChan: make(chan ConnectionRequest, 1),
switchOrgChan: make(chan SwitchOrgRequest, 1),
shutdownChan: make(chan struct{}, 1),
peerStatuses: make(map[int]*PeerStatus),
}
@@ -84,6 +92,7 @@ func (s *API) Start() error {
mux := http.NewServeMux()
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
mux.HandleFunc("/exit", s.handleExit)
s.server = &http.Server{
@@ -138,6 +147,11 @@ func (s *API) GetConnectionChannel() <-chan ConnectionRequest {
return s.connectionChan
}
// GetSwitchOrgChannel returns the channel for receiving org switch requests
func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
return s.switchOrgChan
}
// GetShutdownChannel returns the channel for receiving shutdown requests
func (s *API) GetShutdownChannel() <-chan struct{} {
return s.shutdownChan
@@ -292,3 +306,43 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
"status": "shutdown initiated",
})
}
// handleSwitchOrg handles the /switch-org endpoint
func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req SwitchOrgRequest
decoder := json.NewDecoder(r.Body)
if err := decoder.Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
return
}
// Validate required fields
if req.OrgID == "" {
http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
return
}
logger.Info("Received org switch request to orgId: %s", req.OrgID)
// Send the request to the main goroutine
select {
case s.switchOrgChan <- req:
// Signal sent successfully
default:
// Channel already has a pending request
http.Error(w, "Org switch already in progress", http.StatusConflict)
return
}
// Return a success response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
json.NewEncoder(w).Encode(map[string]string{
"status": "org switch request accepted",
})
}

523
diff Normal file
View File

@@ -0,0 +1,523 @@
diff --git a/api/api.go b/api/api.go
index dd07751..0d2e4ef 100644
--- a/api/api.go
+++ b/api/api.go
@@ -18,6 +18,11 @@ type ConnectionRequest struct {
Endpoint string `json:"endpoint"`
}
+// SwitchOrgRequest defines the structure for switching organizations
+type SwitchOrgRequest struct {
+ OrgID string `json:"orgId"`
+}
+
// PeerStatus represents the status of a peer connection
type PeerStatus struct {
SiteID int `json:"siteId"`
@@ -35,6 +40,7 @@ type StatusResponse struct {
Registered bool `json:"registered"`
TunnelIP string `json:"tunnelIP,omitempty"`
Version string `json:"version,omitempty"`
+ OrgID string `json:"orgId,omitempty"`
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
}
@@ -46,6 +52,7 @@ type API struct {
server *http.Server
connectionChan chan ConnectionRequest
shutdownChan chan struct{}
+ switchOrgChan chan SwitchOrgRequest
statusMu sync.RWMutex
peerStatuses map[int]*PeerStatus
connectedAt time.Time
@@ -53,6 +60,7 @@ type API struct {
isRegistered bool
tunnelIP string
version string
+ orgID string
}
// NewAPI creates a new HTTP server that listens on a TCP address
@@ -61,6 +69,7 @@ func NewAPI(addr string) *API {
addr: addr,
connectionChan: make(chan ConnectionRequest, 1),
shutdownChan: make(chan struct{}, 1),
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
peerStatuses: make(map[int]*PeerStatus),
}
@@ -73,6 +82,7 @@ func NewAPISocket(socketPath string) *API {
socketPath: socketPath,
connectionChan: make(chan ConnectionRequest, 1),
shutdownChan: make(chan struct{}, 1),
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
peerStatuses: make(map[int]*PeerStatus),
}
@@ -85,6 +95,7 @@ func (s *API) Start() error {
mux.HandleFunc("/connect", s.handleConnect)
mux.HandleFunc("/status", s.handleStatus)
mux.HandleFunc("/exit", s.handleExit)
+ mux.HandleFunc("/switch-org", s.handleSwitchOrg)
s.server = &http.Server{
Handler: mux,
@@ -143,6 +154,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} {
return s.shutdownChan
}
+// GetSwitchOrgChannel returns the channel for receiving org switch requests
+func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
+ return s.switchOrgChan
+}
+
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
s.statusMu.Lock()
@@ -198,6 +214,13 @@ func (s *API) SetVersion(version string) {
s.version = version
}
+// SetOrgID sets the org ID
+func (s *API) SetOrgID(orgID string) {
+ s.statusMu.Lock()
+ defer s.statusMu.Unlock()
+ s.orgID = orgID
+}
+
// UpdatePeerRelayStatus updates only the relay status of a peer
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
s.statusMu.Lock()
@@ -261,6 +284,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
Registered: s.isRegistered,
TunnelIP: s.tunnelIP,
Version: s.version,
+ OrgID: s.orgID,
PeerStatuses: s.peerStatuses,
}
@@ -292,3 +316,44 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
"status": "shutdown initiated",
})
}
+
+// handleSwitchOrg handles the /switch-org endpoint
+func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ var req SwitchOrgRequest
+ decoder := json.NewDecoder(r.Body)
+ if err := decoder.Decode(&req); err != nil {
+ http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
+ return
+ }
+
+ // Validate required fields
+ if req.OrgID == "" {
+ http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
+ return
+ }
+
+ logger.Info("Received org switch request to orgId: %s", req.OrgID)
+
+ // Send the request to the main goroutine
+ select {
+ case s.switchOrgChan <- req:
+ // Signal sent successfully
+ default:
+ // Channel already has a signal, don't block
+ http.Error(w, "Org switch already in progress", http.StatusTooManyRequests)
+ return
+ }
+
+ // Return a success response
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusAccepted)
+ json.NewEncoder(w).Encode(map[string]string{
+ "status": "org switch initiated",
+ "orgId": req.OrgID,
+ })
+}
diff --git a/olm/olm.go b/olm/olm.go
index 78080c4..5e292d6 100644
--- a/olm/olm.go
+++ b/olm/olm.go
@@ -58,6 +58,58 @@ type Config struct {
OrgID string
}
+// tunnelState holds all the active tunnel resources that need cleanup
+type tunnelState struct {
+ dev *device.Device
+ tdev tun.Device
+ uapiListener net.Listener
+ peerMonitor *peermonitor.PeerMonitor
+ stopRegister func()
+ connected bool
+}
+
+// teardownTunnel cleans up all tunnel resources
+func teardownTunnel(state *tunnelState) {
+ if state == nil {
+ return
+ }
+
+ logger.Info("Tearing down tunnel...")
+
+ // Stop registration messages
+ if state.stopRegister != nil {
+ state.stopRegister()
+ state.stopRegister = nil
+ }
+
+ // Stop peer monitor
+ if state.peerMonitor != nil {
+ state.peerMonitor.Stop()
+ state.peerMonitor = nil
+ }
+
+ // Close UAPI listener
+ if state.uapiListener != nil {
+ state.uapiListener.Close()
+ state.uapiListener = nil
+ }
+
+ // Close WireGuard device
+ if state.dev != nil {
+ state.dev.Close()
+ state.dev = nil
+ }
+
+ // Close TUN device
+ if state.tdev != nil {
+ state.tdev.Close()
+ state.tdev = nil
+ }
+
+ state.connected = false
+ logger.Info("Tunnel teardown complete")
+}
+
func Run(ctx context.Context, config Config) {
// Create a cancellable context for internal shutdown control
ctx, cancel := context.WithCancel(ctx)
@@ -75,14 +127,14 @@ func Run(ctx context.Context, config Config) {
pingTimeout = config.PingTimeoutDuration
doHolepunch = config.Holepunch
privateKey wgtypes.Key
- connected bool
- dev *device.Device
wgData WgData
holePunchData HolePunchData
- uapiListener net.Listener
- tdev tun.Device
+ orgID = config.OrgID
)
+ // Tunnel state that can be torn down and recreated
+ tunnel := &tunnelState{}
+
stopHolepunch = make(chan struct{})
stopPing = make(chan struct{})
@@ -110,6 +162,7 @@ func Run(ctx context.Context, config Config) {
}
apiServer.SetVersion(config.Version)
+ apiServer.SetOrgID(orgID)
if err := apiServer.Start(); err != nil {
logger.Fatal("Failed to start HTTP server: %v", err)
}
@@ -249,14 +302,14 @@ func Run(ctx context.Context, config Config) {
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
logger.Debug("Received message: %v", msg.Data)
- if connected {
+ if tunnel.connected {
logger.Info("Already connected. Ignoring new connection request.")
return
}
- if stopRegister != nil {
- stopRegister()
- stopRegister = nil
+ if tunnel.stopRegister != nil {
+ tunnel.stopRegister()
+ tunnel.stopRegister = nil
}
close(stopHolepunch)
@@ -266,9 +319,9 @@ func Run(ctx context.Context, config Config) {
time.Sleep(500 * time.Millisecond)
// if there is an existing tunnel then close it
- if dev != nil {
+ if tunnel.dev != nil {
logger.Info("Got new message. Closing existing tunnel!")
- dev.Close()
+ tunnel.dev.Close()
}
jsonData, err := json.Marshal(msg.Data)
@@ -282,7 +335,7 @@ func Run(ctx context.Context, config Config) {
return
}
- tdev, err = func() (tun.Device, error) {
+ tunnel.tdev, err = func() (tun.Device, error) {
if runtime.GOOS == "darwin" {
interfaceName, err := findUnusedUTUN()
if err != nil {
@@ -301,7 +354,7 @@ func Run(ctx context.Context, config Config) {
return
}
- if realInterfaceName, err2 := tdev.Name(); err2 == nil {
+ if realInterfaceName, err2 := tunnel.tdev.Name(); err2 == nil {
interfaceName = realInterfaceName
}
@@ -321,9 +374,9 @@ func Run(ctx context.Context, config Config) {
return
}
- dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
+ tunnel.dev = device.NewDevice(tunnel.tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
- uapiListener, err = uapiListen(interfaceName, fileUAPI)
+ tunnel.uapiListener, err = uapiListen(interfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
@@ -331,16 +384,16 @@ func Run(ctx context.Context, config Config) {
go func() {
for {
- conn, err := uapiListener.Accept()
+ conn, err := tunnel.uapiListener.Accept()
if err != nil {
return
}
- go dev.IpcHandle(conn)
+ go tunnel.dev.IpcHandle(conn)
}
}()
logger.Info("UAPI listener started")
- if err = dev.Up(); err != nil {
+ if err = tunnel.dev.Up(); err != nil {
logger.Error("Failed to bring up WireGuard device: %v", err)
}
if err = ConfigureInterface(interfaceName, wgData); err != nil {
@@ -350,7 +403,7 @@ func Run(ctx context.Context, config Config) {
apiServer.SetTunnelIP(wgData.TunnelIP)
}
- peerMonitor = peermonitor.NewPeerMonitor(
+ tunnel.peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
if apiServer != nil {
// Find the site config to get endpoint information
@@ -375,7 +428,7 @@ func Run(ctx context.Context, config Config) {
},
fixKey(privateKey.String()),
olm,
- dev,
+ tunnel.dev,
doHolepunch,
)
@@ -388,7 +441,7 @@ func Run(ctx context.Context, config Config) {
// Format the endpoint before configuring the peer.
site.Endpoint = formatEndpoint(site.Endpoint)
- if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
+ if err := ConfigurePeer(tunnel.dev, *site, privateKey, endpoint); err != nil {
logger.Error("Failed to configure peer: %v", err)
return
}
@@ -404,13 +457,13 @@ func Run(ctx context.Context, config Config) {
logger.Info("Configured peer %s", site.PublicKey)
}
- peerMonitor.Start()
+ tunnel.peerMonitor.Start()
if apiServer != nil {
apiServer.SetRegistered(true)
}
- connected = true
+ tunnel.connected = true
logger.Info("WireGuard device created.")
})
@@ -441,7 +494,7 @@ func Run(ctx context.Context, config Config) {
}
// Update the peer in WireGuard
- if dev != nil {
+ if tunnel.dev != nil {
// Find the existing peer to get old data
var oldRemoteSubnets string
var oldPublicKey string
@@ -456,7 +509,7 @@ func Run(ctx context.Context, config Config) {
// If the public key has changed, remove the old peer first
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
- if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
+ if err := RemovePeer(tunnel.dev, updateData.SiteId, oldPublicKey); err != nil {
logger.Error("Failed to remove old peer: %v", err)
return
}
@@ -465,7 +518,7 @@ func Run(ctx context.Context, config Config) {
// Format the endpoint before updating the peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to update peer: %v", err)
return
}
@@ -524,11 +577,11 @@ func Run(ctx context.Context, config Config) {
}
// Add the peer to WireGuard
- if dev != nil {
+ if tunnel.dev != nil {
// Format the endpoint before adding the new peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to add peer: %v", err)
return
}
@@ -585,8 +638,8 @@ func Run(ctx context.Context, config Config) {
}
// Remove the peer from WireGuard
- if dev != nil {
- if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
+ if tunnel.dev != nil {
+ if err := RemovePeer(tunnel.dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
logger.Error("Failed to remove peer: %v", err)
// Send error response if needed
return
@@ -640,7 +693,7 @@ func Run(ctx context.Context, config Config) {
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
}
- peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
+ tunnel.peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
})
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
@@ -673,7 +726,7 @@ func Run(ctx context.Context, config Config) {
apiServer.SetConnectionStatus(true)
}
- if connected {
+ if tunnel.connected {
logger.Debug("Already connected, skipping registration")
return nil
}
@@ -682,11 +735,11 @@ func Run(ctx context.Context, config Config) {
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
- stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !doHolepunch,
"olmVersion": config.Version,
- "orgId": config.OrgID,
+ "orgId": orgID,
}, 1*time.Second)
go keepSendingPing(olm)
@@ -705,6 +758,49 @@ func Run(ctx context.Context, config Config) {
}
defer olm.Close()
+ // Listen for org switch requests from the API (after olm is created)
+ if apiServer != nil {
+ go func() {
+ for req := range apiServer.GetSwitchOrgChannel() {
+ logger.Info("Org switch requested via API to orgId: %s", req.OrgID)
+
+ // Update the orgId
+ orgID = req.OrgID
+
+ // Teardown existing tunnel
+ teardownTunnel(tunnel)
+
+ // Reset tunnel state
+ tunnel = &tunnelState{}
+
+ // Stop holepunch
+ select {
+ case <-stopHolepunch:
+ // Channel already closed
+ default:
+ close(stopHolepunch)
+ }
+ stopHolepunch = make(chan struct{})
+
+ // Clear API server state
+ apiServer.SetRegistered(false)
+ apiServer.SetTunnelIP("")
+ apiServer.SetOrgID(orgID)
+
+ // Send new registration message with updated orgId
+ publicKey := privateKey.PublicKey()
+ logger.Info("Sending registration message with new orgId: %s", orgID)
+
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
+ "publicKey": publicKey.String(),
+ "relay": !doHolepunch,
+ "olmVersion": config.Version,
+ "orgId": orgID,
+ }, 1*time.Second)
+ }
+ }()
+ }
+
select {
case <-ctx.Done():
logger.Info("Context cancelled")
@@ -717,9 +813,9 @@ func Run(ctx context.Context, config Config) {
close(stopHolepunch)
}
- if stopRegister != nil {
- stopRegister()
- stopRegister = nil
+ if tunnel.stopRegister != nil {
+ tunnel.stopRegister()
+ tunnel.stopRegister = nil
}
select {
@@ -729,16 +825,8 @@ func Run(ctx context.Context, config Config) {
close(stopPing)
}
- if peerMonitor != nil {
- peerMonitor.Stop()
- }
-
- if uapiListener != nil {
- uapiListener.Close()
- }
- if dev != nil {
- dev.Close()
- }
+ // Use teardownTunnel to clean up all tunnel resources
+ teardownTunnel(tunnel)
if apiServer != nil {
apiServer.Stop()

View File

@@ -699,6 +699,77 @@ func Run(ctx context.Context, config Config) {
olmToken = token
})
// Listen for org switch requests from the API
if apiServer != nil {
go func() {
for req := range apiServer.GetSwitchOrgChannel() {
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
// Update the config with the new orgId
config.OrgID = req.OrgID
// Mark as not connected to trigger re-registration
connected = false
// Stop registration if running
if stopRegister != nil {
stopRegister()
stopRegister = nil
}
// Stop hole punching
select {
case <-stopHolepunch:
// Already closed
default:
close(stopHolepunch)
}
stopHolepunch = make(chan struct{})
// Stop peer monitor
if peerMonitor != nil {
peerMonitor.Stop()
peerMonitor = nil
}
// Close the WireGuard device
if dev != nil {
logger.Info("Closing existing WireGuard device for org switch")
dev.Close()
dev = nil
}
// Close UAPI listener
if uapiListener != nil {
uapiListener.Close()
uapiListener = nil
}
// Close TUN device
if tdev != nil {
tdev.Close()
tdev = nil
}
// Clear peer statuses in API
if apiServer != nil {
apiServer.SetRegistered(false)
apiServer.SetTunnelIP("")
}
// Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", config.OrgID)
publicKey := privateKey.PublicKey()
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": !doHolepunch,
"olmVersion": config.Version,
"orgId": config.OrgID,
}, 1*time.Second)
}
}()
}
// Connect to the WebSocket server
if err := olm.Connect(); err != nil {
logger.Fatal("Failed to connect to server: %v", err)