mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Update switching orgs
This commit is contained in:
54
api/api.go
54
api/api.go
@@ -18,6 +18,11 @@ type ConnectionRequest struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
// SwitchOrgRequest defines the structure for switching organizations
|
||||
type SwitchOrgRequest struct {
|
||||
OrgID string `json:"orgId"`
|
||||
}
|
||||
|
||||
// PeerStatus represents the status of a peer connection
|
||||
type PeerStatus struct {
|
||||
SiteID int `json:"siteId"`
|
||||
@@ -45,6 +50,7 @@ type API struct {
|
||||
listener net.Listener
|
||||
server *http.Server
|
||||
connectionChan chan ConnectionRequest
|
||||
switchOrgChan chan SwitchOrgRequest
|
||||
shutdownChan chan struct{}
|
||||
statusMu sync.RWMutex
|
||||
peerStatuses map[int]*PeerStatus
|
||||
@@ -60,6 +66,7 @@ func NewAPI(addr string) *API {
|
||||
s := &API{
|
||||
addr: addr,
|
||||
connectionChan: make(chan ConnectionRequest, 1),
|
||||
switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||
shutdownChan: make(chan struct{}, 1),
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
@@ -72,6 +79,7 @@ func NewAPISocket(socketPath string) *API {
|
||||
s := &API{
|
||||
socketPath: socketPath,
|
||||
connectionChan: make(chan ConnectionRequest, 1),
|
||||
switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||
shutdownChan: make(chan struct{}, 1),
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
@@ -84,6 +92,7 @@ func (s *API) Start() error {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/connect", s.handleConnect)
|
||||
mux.HandleFunc("/status", s.handleStatus)
|
||||
mux.HandleFunc("/switch-org", s.handleSwitchOrg)
|
||||
mux.HandleFunc("/exit", s.handleExit)
|
||||
|
||||
s.server = &http.Server{
|
||||
@@ -138,6 +147,11 @@ func (s *API) GetConnectionChannel() <-chan ConnectionRequest {
|
||||
return s.connectionChan
|
||||
}
|
||||
|
||||
// GetSwitchOrgChannel returns the channel for receiving org switch requests
|
||||
func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
|
||||
return s.switchOrgChan
|
||||
}
|
||||
|
||||
// GetShutdownChannel returns the channel for receiving shutdown requests
|
||||
func (s *API) GetShutdownChannel() <-chan struct{} {
|
||||
return s.shutdownChan
|
||||
@@ -292,3 +306,43 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
||||
"status": "shutdown initiated",
|
||||
})
|
||||
}
|
||||
|
||||
// handleSwitchOrg handles the /switch-org endpoint
|
||||
func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req SwitchOrgRequest
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
if err := decoder.Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.OrgID == "" {
|
||||
http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Received org switch request to orgId: %s", req.OrgID)
|
||||
|
||||
// Send the request to the main goroutine
|
||||
select {
|
||||
case s.switchOrgChan <- req:
|
||||
// Signal sent successfully
|
||||
default:
|
||||
// Channel already has a pending request
|
||||
http.Error(w, "Org switch already in progress", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
|
||||
// Return a success response
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "org switch request accepted",
|
||||
})
|
||||
}
|
||||
|
||||
523
diff
Normal file
523
diff
Normal file
@@ -0,0 +1,523 @@
|
||||
diff --git a/api/api.go b/api/api.go
|
||||
index dd07751..0d2e4ef 100644
|
||||
--- a/api/api.go
|
||||
+++ b/api/api.go
|
||||
@@ -18,6 +18,11 @@ type ConnectionRequest struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
+// SwitchOrgRequest defines the structure for switching organizations
|
||||
+type SwitchOrgRequest struct {
|
||||
+ OrgID string `json:"orgId"`
|
||||
+}
|
||||
+
|
||||
// PeerStatus represents the status of a peer connection
|
||||
type PeerStatus struct {
|
||||
SiteID int `json:"siteId"`
|
||||
@@ -35,6 +40,7 @@ type StatusResponse struct {
|
||||
Registered bool `json:"registered"`
|
||||
TunnelIP string `json:"tunnelIP,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
+ OrgID string `json:"orgId,omitempty"`
|
||||
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
||||
}
|
||||
|
||||
@@ -46,6 +52,7 @@ type API struct {
|
||||
server *http.Server
|
||||
connectionChan chan ConnectionRequest
|
||||
shutdownChan chan struct{}
|
||||
+ switchOrgChan chan SwitchOrgRequest
|
||||
statusMu sync.RWMutex
|
||||
peerStatuses map[int]*PeerStatus
|
||||
connectedAt time.Time
|
||||
@@ -53,6 +60,7 @@ type API struct {
|
||||
isRegistered bool
|
||||
tunnelIP string
|
||||
version string
|
||||
+ orgID string
|
||||
}
|
||||
|
||||
// NewAPI creates a new HTTP server that listens on a TCP address
|
||||
@@ -61,6 +69,7 @@ func NewAPI(addr string) *API {
|
||||
addr: addr,
|
||||
connectionChan: make(chan ConnectionRequest, 1),
|
||||
shutdownChan: make(chan struct{}, 1),
|
||||
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
|
||||
@@ -73,6 +82,7 @@ func NewAPISocket(socketPath string) *API {
|
||||
socketPath: socketPath,
|
||||
connectionChan: make(chan ConnectionRequest, 1),
|
||||
shutdownChan: make(chan struct{}, 1),
|
||||
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
|
||||
peerStatuses: make(map[int]*PeerStatus),
|
||||
}
|
||||
|
||||
@@ -85,6 +95,7 @@ func (s *API) Start() error {
|
||||
mux.HandleFunc("/connect", s.handleConnect)
|
||||
mux.HandleFunc("/status", s.handleStatus)
|
||||
mux.HandleFunc("/exit", s.handleExit)
|
||||
+ mux.HandleFunc("/switch-org", s.handleSwitchOrg)
|
||||
|
||||
s.server = &http.Server{
|
||||
Handler: mux,
|
||||
@@ -143,6 +154,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} {
|
||||
return s.shutdownChan
|
||||
}
|
||||
|
||||
+// GetSwitchOrgChannel returns the channel for receiving org switch requests
|
||||
+func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
|
||||
+ return s.switchOrgChan
|
||||
+}
|
||||
+
|
||||
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
||||
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
||||
s.statusMu.Lock()
|
||||
@@ -198,6 +214,13 @@ func (s *API) SetVersion(version string) {
|
||||
s.version = version
|
||||
}
|
||||
|
||||
+// SetOrgID sets the org ID
|
||||
+func (s *API) SetOrgID(orgID string) {
|
||||
+ s.statusMu.Lock()
|
||||
+ defer s.statusMu.Unlock()
|
||||
+ s.orgID = orgID
|
||||
+}
|
||||
+
|
||||
// UpdatePeerRelayStatus updates only the relay status of a peer
|
||||
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
|
||||
s.statusMu.Lock()
|
||||
@@ -261,6 +284,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
Registered: s.isRegistered,
|
||||
TunnelIP: s.tunnelIP,
|
||||
Version: s.version,
|
||||
+ OrgID: s.orgID,
|
||||
PeerStatuses: s.peerStatuses,
|
||||
}
|
||||
|
||||
@@ -292,3 +316,44 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
||||
"status": "shutdown initiated",
|
||||
})
|
||||
}
|
||||
+
|
||||
+// handleSwitchOrg handles the /switch-org endpoint
|
||||
+func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
||||
+ if r.Method != http.MethodPost {
|
||||
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
+ return
|
||||
+ }
|
||||
+
|
||||
+ var req SwitchOrgRequest
|
||||
+ decoder := json.NewDecoder(r.Body)
|
||||
+ if err := decoder.Decode(&req); err != nil {
|
||||
+ http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
||||
+ return
|
||||
+ }
|
||||
+
|
||||
+ // Validate required fields
|
||||
+ if req.OrgID == "" {
|
||||
+ http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
|
||||
+ return
|
||||
+ }
|
||||
+
|
||||
+ logger.Info("Received org switch request to orgId: %s", req.OrgID)
|
||||
+
|
||||
+ // Send the request to the main goroutine
|
||||
+ select {
|
||||
+ case s.switchOrgChan <- req:
|
||||
+ // Signal sent successfully
|
||||
+ default:
|
||||
+ // Channel already has a signal, don't block
|
||||
+ http.Error(w, "Org switch already in progress", http.StatusTooManyRequests)
|
||||
+ return
|
||||
+ }
|
||||
+
|
||||
+ // Return a success response
|
||||
+ w.Header().Set("Content-Type", "application/json")
|
||||
+ w.WriteHeader(http.StatusAccepted)
|
||||
+ json.NewEncoder(w).Encode(map[string]string{
|
||||
+ "status": "org switch initiated",
|
||||
+ "orgId": req.OrgID,
|
||||
+ })
|
||||
+}
|
||||
diff --git a/olm/olm.go b/olm/olm.go
|
||||
index 78080c4..5e292d6 100644
|
||||
--- a/olm/olm.go
|
||||
+++ b/olm/olm.go
|
||||
@@ -58,6 +58,58 @@ type Config struct {
|
||||
OrgID string
|
||||
}
|
||||
|
||||
+// tunnelState holds all the active tunnel resources that need cleanup
|
||||
+type tunnelState struct {
|
||||
+ dev *device.Device
|
||||
+ tdev tun.Device
|
||||
+ uapiListener net.Listener
|
||||
+ peerMonitor *peermonitor.PeerMonitor
|
||||
+ stopRegister func()
|
||||
+ connected bool
|
||||
+}
|
||||
+
|
||||
+// teardownTunnel cleans up all tunnel resources
|
||||
+func teardownTunnel(state *tunnelState) {
|
||||
+ if state == nil {
|
||||
+ return
|
||||
+ }
|
||||
+
|
||||
+ logger.Info("Tearing down tunnel...")
|
||||
+
|
||||
+ // Stop registration messages
|
||||
+ if state.stopRegister != nil {
|
||||
+ state.stopRegister()
|
||||
+ state.stopRegister = nil
|
||||
+ }
|
||||
+
|
||||
+ // Stop peer monitor
|
||||
+ if state.peerMonitor != nil {
|
||||
+ state.peerMonitor.Stop()
|
||||
+ state.peerMonitor = nil
|
||||
+ }
|
||||
+
|
||||
+ // Close UAPI listener
|
||||
+ if state.uapiListener != nil {
|
||||
+ state.uapiListener.Close()
|
||||
+ state.uapiListener = nil
|
||||
+ }
|
||||
+
|
||||
+ // Close WireGuard device
|
||||
+ if state.dev != nil {
|
||||
+ state.dev.Close()
|
||||
+ state.dev = nil
|
||||
+ }
|
||||
+
|
||||
+ // Close TUN device
|
||||
+ if state.tdev != nil {
|
||||
+ state.tdev.Close()
|
||||
+ state.tdev = nil
|
||||
+ }
|
||||
+
|
||||
+ state.connected = false
|
||||
+ logger.Info("Tunnel teardown complete")
|
||||
+}
|
||||
+
|
||||
func Run(ctx context.Context, config Config) {
|
||||
// Create a cancellable context for internal shutdown control
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
@@ -75,14 +127,14 @@ func Run(ctx context.Context, config Config) {
|
||||
pingTimeout = config.PingTimeoutDuration
|
||||
doHolepunch = config.Holepunch
|
||||
privateKey wgtypes.Key
|
||||
- connected bool
|
||||
- dev *device.Device
|
||||
wgData WgData
|
||||
holePunchData HolePunchData
|
||||
- uapiListener net.Listener
|
||||
- tdev tun.Device
|
||||
+ orgID = config.OrgID
|
||||
)
|
||||
|
||||
+ // Tunnel state that can be torn down and recreated
|
||||
+ tunnel := &tunnelState{}
|
||||
+
|
||||
stopHolepunch = make(chan struct{})
|
||||
stopPing = make(chan struct{})
|
||||
|
||||
@@ -110,6 +162,7 @@ func Run(ctx context.Context, config Config) {
|
||||
}
|
||||
|
||||
apiServer.SetVersion(config.Version)
|
||||
+ apiServer.SetOrgID(orgID)
|
||||
if err := apiServer.Start(); err != nil {
|
||||
logger.Fatal("Failed to start HTTP server: %v", err)
|
||||
}
|
||||
@@ -249,14 +302,14 @@ func Run(ctx context.Context, config Config) {
|
||||
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received message: %v", msg.Data)
|
||||
|
||||
- if connected {
|
||||
+ if tunnel.connected {
|
||||
logger.Info("Already connected. Ignoring new connection request.")
|
||||
return
|
||||
}
|
||||
|
||||
- if stopRegister != nil {
|
||||
- stopRegister()
|
||||
- stopRegister = nil
|
||||
+ if tunnel.stopRegister != nil {
|
||||
+ tunnel.stopRegister()
|
||||
+ tunnel.stopRegister = nil
|
||||
}
|
||||
|
||||
close(stopHolepunch)
|
||||
@@ -266,9 +319,9 @@ func Run(ctx context.Context, config Config) {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// if there is an existing tunnel then close it
|
||||
- if dev != nil {
|
||||
+ if tunnel.dev != nil {
|
||||
logger.Info("Got new message. Closing existing tunnel!")
|
||||
- dev.Close()
|
||||
+ tunnel.dev.Close()
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
@@ -282,7 +335,7 @@ func Run(ctx context.Context, config Config) {
|
||||
return
|
||||
}
|
||||
|
||||
- tdev, err = func() (tun.Device, error) {
|
||||
+ tunnel.tdev, err = func() (tun.Device, error) {
|
||||
if runtime.GOOS == "darwin" {
|
||||
interfaceName, err := findUnusedUTUN()
|
||||
if err != nil {
|
||||
@@ -301,7 +354,7 @@ func Run(ctx context.Context, config Config) {
|
||||
return
|
||||
}
|
||||
|
||||
- if realInterfaceName, err2 := tdev.Name(); err2 == nil {
|
||||
+ if realInterfaceName, err2 := tunnel.tdev.Name(); err2 == nil {
|
||||
interfaceName = realInterfaceName
|
||||
}
|
||||
|
||||
@@ -321,9 +374,9 @@ func Run(ctx context.Context, config Config) {
|
||||
return
|
||||
}
|
||||
|
||||
- dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
||||
+ tunnel.dev = device.NewDevice(tunnel.tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
||||
|
||||
- uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||
+ tunnel.uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||
if err != nil {
|
||||
logger.Error("Failed to listen on uapi socket: %v", err)
|
||||
os.Exit(1)
|
||||
@@ -331,16 +384,16 @@ func Run(ctx context.Context, config Config) {
|
||||
|
||||
go func() {
|
||||
for {
|
||||
- conn, err := uapiListener.Accept()
|
||||
+ conn, err := tunnel.uapiListener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
- go dev.IpcHandle(conn)
|
||||
+ go tunnel.dev.IpcHandle(conn)
|
||||
}
|
||||
}()
|
||||
logger.Info("UAPI listener started")
|
||||
|
||||
- if err = dev.Up(); err != nil {
|
||||
+ if err = tunnel.dev.Up(); err != nil {
|
||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||
}
|
||||
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
||||
@@ -350,7 +403,7 @@ func Run(ctx context.Context, config Config) {
|
||||
apiServer.SetTunnelIP(wgData.TunnelIP)
|
||||
}
|
||||
|
||||
- peerMonitor = peermonitor.NewPeerMonitor(
|
||||
+ tunnel.peerMonitor = peermonitor.NewPeerMonitor(
|
||||
func(siteID int, connected bool, rtt time.Duration) {
|
||||
if apiServer != nil {
|
||||
// Find the site config to get endpoint information
|
||||
@@ -375,7 +428,7 @@ func Run(ctx context.Context, config Config) {
|
||||
},
|
||||
fixKey(privateKey.String()),
|
||||
olm,
|
||||
- dev,
|
||||
+ tunnel.dev,
|
||||
doHolepunch,
|
||||
)
|
||||
|
||||
@@ -388,7 +441,7 @@ func Run(ctx context.Context, config Config) {
|
||||
// Format the endpoint before configuring the peer.
|
||||
site.Endpoint = formatEndpoint(site.Endpoint)
|
||||
|
||||
- if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
|
||||
+ if err := ConfigurePeer(tunnel.dev, *site, privateKey, endpoint); err != nil {
|
||||
logger.Error("Failed to configure peer: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -404,13 +457,13 @@ func Run(ctx context.Context, config Config) {
|
||||
logger.Info("Configured peer %s", site.PublicKey)
|
||||
}
|
||||
|
||||
- peerMonitor.Start()
|
||||
+ tunnel.peerMonitor.Start()
|
||||
|
||||
if apiServer != nil {
|
||||
apiServer.SetRegistered(true)
|
||||
}
|
||||
|
||||
- connected = true
|
||||
+ tunnel.connected = true
|
||||
|
||||
logger.Info("WireGuard device created.")
|
||||
})
|
||||
@@ -441,7 +494,7 @@ func Run(ctx context.Context, config Config) {
|
||||
}
|
||||
|
||||
// Update the peer in WireGuard
|
||||
- if dev != nil {
|
||||
+ if tunnel.dev != nil {
|
||||
// Find the existing peer to get old data
|
||||
var oldRemoteSubnets string
|
||||
var oldPublicKey string
|
||||
@@ -456,7 +509,7 @@ func Run(ctx context.Context, config Config) {
|
||||
// If the public key has changed, remove the old peer first
|
||||
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
|
||||
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
|
||||
- if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
|
||||
+ if err := RemovePeer(tunnel.dev, updateData.SiteId, oldPublicKey); err != nil {
|
||||
logger.Error("Failed to remove old peer: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -465,7 +518,7 @@ func Run(ctx context.Context, config Config) {
|
||||
// Format the endpoint before updating the peer.
|
||||
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||
|
||||
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
|
||||
logger.Error("Failed to update peer: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -524,11 +577,11 @@ func Run(ctx context.Context, config Config) {
|
||||
}
|
||||
|
||||
// Add the peer to WireGuard
|
||||
- if dev != nil {
|
||||
+ if tunnel.dev != nil {
|
||||
// Format the endpoint before adding the new peer.
|
||||
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
||||
|
||||
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
|
||||
logger.Error("Failed to add peer: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -585,8 +638,8 @@ func Run(ctx context.Context, config Config) {
|
||||
}
|
||||
|
||||
// Remove the peer from WireGuard
|
||||
- if dev != nil {
|
||||
- if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
|
||||
+ if tunnel.dev != nil {
|
||||
+ if err := RemovePeer(tunnel.dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
|
||||
logger.Error("Failed to remove peer: %v", err)
|
||||
// Send error response if needed
|
||||
return
|
||||
@@ -640,7 +693,7 @@ func Run(ctx context.Context, config Config) {
|
||||
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
|
||||
}
|
||||
|
||||
- peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
||||
+ tunnel.peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
||||
})
|
||||
|
||||
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
|
||||
@@ -673,7 +726,7 @@ func Run(ctx context.Context, config Config) {
|
||||
apiServer.SetConnectionStatus(true)
|
||||
}
|
||||
|
||||
- if connected {
|
||||
+ if tunnel.connected {
|
||||
logger.Debug("Already connected, skipping registration")
|
||||
return nil
|
||||
}
|
||||
@@ -682,11 +735,11 @@ func Run(ctx context.Context, config Config) {
|
||||
|
||||
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
|
||||
|
||||
- stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"relay": !doHolepunch,
|
||||
"olmVersion": config.Version,
|
||||
- "orgId": config.OrgID,
|
||||
+ "orgId": orgID,
|
||||
}, 1*time.Second)
|
||||
|
||||
go keepSendingPing(olm)
|
||||
@@ -705,6 +758,49 @@ func Run(ctx context.Context, config Config) {
|
||||
}
|
||||
defer olm.Close()
|
||||
|
||||
+ // Listen for org switch requests from the API (after olm is created)
|
||||
+ if apiServer != nil {
|
||||
+ go func() {
|
||||
+ for req := range apiServer.GetSwitchOrgChannel() {
|
||||
+ logger.Info("Org switch requested via API to orgId: %s", req.OrgID)
|
||||
+
|
||||
+ // Update the orgId
|
||||
+ orgID = req.OrgID
|
||||
+
|
||||
+ // Teardown existing tunnel
|
||||
+ teardownTunnel(tunnel)
|
||||
+
|
||||
+ // Reset tunnel state
|
||||
+ tunnel = &tunnelState{}
|
||||
+
|
||||
+ // Stop holepunch
|
||||
+ select {
|
||||
+ case <-stopHolepunch:
|
||||
+ // Channel already closed
|
||||
+ default:
|
||||
+ close(stopHolepunch)
|
||||
+ }
|
||||
+ stopHolepunch = make(chan struct{})
|
||||
+
|
||||
+ // Clear API server state
|
||||
+ apiServer.SetRegistered(false)
|
||||
+ apiServer.SetTunnelIP("")
|
||||
+ apiServer.SetOrgID(orgID)
|
||||
+
|
||||
+ // Send new registration message with updated orgId
|
||||
+ publicKey := privateKey.PublicKey()
|
||||
+ logger.Info("Sending registration message with new orgId: %s", orgID)
|
||||
+
|
||||
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
+ "publicKey": publicKey.String(),
|
||||
+ "relay": !doHolepunch,
|
||||
+ "olmVersion": config.Version,
|
||||
+ "orgId": orgID,
|
||||
+ }, 1*time.Second)
|
||||
+ }
|
||||
+ }()
|
||||
+ }
|
||||
+
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Info("Context cancelled")
|
||||
@@ -717,9 +813,9 @@ func Run(ctx context.Context, config Config) {
|
||||
close(stopHolepunch)
|
||||
}
|
||||
|
||||
- if stopRegister != nil {
|
||||
- stopRegister()
|
||||
- stopRegister = nil
|
||||
+ if tunnel.stopRegister != nil {
|
||||
+ tunnel.stopRegister()
|
||||
+ tunnel.stopRegister = nil
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -729,16 +825,8 @@ func Run(ctx context.Context, config Config) {
|
||||
close(stopPing)
|
||||
}
|
||||
|
||||
- if peerMonitor != nil {
|
||||
- peerMonitor.Stop()
|
||||
- }
|
||||
-
|
||||
- if uapiListener != nil {
|
||||
- uapiListener.Close()
|
||||
- }
|
||||
- if dev != nil {
|
||||
- dev.Close()
|
||||
- }
|
||||
+ // Use teardownTunnel to clean up all tunnel resources
|
||||
+ teardownTunnel(tunnel)
|
||||
|
||||
if apiServer != nil {
|
||||
apiServer.Stop()
|
||||
71
olm/olm.go
71
olm/olm.go
@@ -699,6 +699,77 @@ func Run(ctx context.Context, config Config) {
|
||||
olmToken = token
|
||||
})
|
||||
|
||||
// Listen for org switch requests from the API
|
||||
if apiServer != nil {
|
||||
go func() {
|
||||
for req := range apiServer.GetSwitchOrgChannel() {
|
||||
logger.Info("Processing org switch request to orgId: %s", req.OrgID)
|
||||
|
||||
// Update the config with the new orgId
|
||||
config.OrgID = req.OrgID
|
||||
|
||||
// Mark as not connected to trigger re-registration
|
||||
connected = false
|
||||
|
||||
// Stop registration if running
|
||||
if stopRegister != nil {
|
||||
stopRegister()
|
||||
stopRegister = nil
|
||||
}
|
||||
|
||||
// Stop hole punching
|
||||
select {
|
||||
case <-stopHolepunch:
|
||||
// Already closed
|
||||
default:
|
||||
close(stopHolepunch)
|
||||
}
|
||||
stopHolepunch = make(chan struct{})
|
||||
|
||||
// Stop peer monitor
|
||||
if peerMonitor != nil {
|
||||
peerMonitor.Stop()
|
||||
peerMonitor = nil
|
||||
}
|
||||
|
||||
// Close the WireGuard device
|
||||
if dev != nil {
|
||||
logger.Info("Closing existing WireGuard device for org switch")
|
||||
dev.Close()
|
||||
dev = nil
|
||||
}
|
||||
|
||||
// Close UAPI listener
|
||||
if uapiListener != nil {
|
||||
uapiListener.Close()
|
||||
uapiListener = nil
|
||||
}
|
||||
|
||||
// Close TUN device
|
||||
if tdev != nil {
|
||||
tdev.Close()
|
||||
tdev = nil
|
||||
}
|
||||
|
||||
// Clear peer statuses in API
|
||||
if apiServer != nil {
|
||||
apiServer.SetRegistered(false)
|
||||
apiServer.SetTunnelIP("")
|
||||
}
|
||||
|
||||
// Trigger re-registration with new orgId
|
||||
logger.Info("Re-registering with new orgId: %s", config.OrgID)
|
||||
publicKey := privateKey.PublicKey()
|
||||
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"relay": !doHolepunch,
|
||||
"olmVersion": config.Version,
|
||||
"orgId": config.OrgID,
|
||||
}, 1*time.Second)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Connect to the WebSocket server
|
||||
if err := olm.Connect(); err != nil {
|
||||
logger.Fatal("Failed to connect to server: %v", err)
|
||||
|
||||
Reference in New Issue
Block a user