mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 14:06:41 +00:00
524 lines
16 KiB
Plaintext
524 lines
16 KiB
Plaintext
diff --git a/api/api.go b/api/api.go
|
|
index dd07751..0d2e4ef 100644
|
|
--- a/api/api.go
|
|
+++ b/api/api.go
|
|
@@ -18,6 +18,11 @@ type ConnectionRequest struct {
|
|
Endpoint string `json:"endpoint"`
|
|
}
|
|
|
|
+// SwitchOrgRequest is the JSON request body accepted by the /switch-org endpoint
|
|
+type SwitchOrgRequest struct {
|
|
+ OrgID string `json:"orgId"`
|
|
+}
|
|
+
|
|
// PeerStatus represents the status of a peer connection
|
|
type PeerStatus struct {
|
|
SiteID int `json:"siteId"`
|
|
@@ -35,6 +40,7 @@ type StatusResponse struct {
|
|
Registered bool `json:"registered"`
|
|
TunnelIP string `json:"tunnelIP,omitempty"`
|
|
Version string `json:"version,omitempty"`
|
|
+ OrgID string `json:"orgId,omitempty"`
|
|
PeerStatuses map[int]*PeerStatus `json:"peers,omitempty"`
|
|
}
|
|
|
|
@@ -46,6 +52,7 @@ type API struct {
|
|
server *http.Server
|
|
connectionChan chan ConnectionRequest
|
|
shutdownChan chan struct{}
|
|
+ switchOrgChan chan SwitchOrgRequest
|
|
statusMu sync.RWMutex
|
|
peerStatuses map[int]*PeerStatus
|
|
connectedAt time.Time
|
|
@@ -53,6 +60,7 @@ type API struct {
|
|
isRegistered bool
|
|
tunnelIP string
|
|
version string
|
|
+ orgID string
|
|
}
|
|
|
|
// NewAPI creates a new HTTP server that listens on a TCP address
|
|
@@ -61,6 +69,7 @@ func NewAPI(addr string) *API {
|
|
addr: addr,
|
|
connectionChan: make(chan ConnectionRequest, 1),
|
|
shutdownChan: make(chan struct{}, 1),
|
|
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
|
|
peerStatuses: make(map[int]*PeerStatus),
|
|
}
|
|
|
|
@@ -73,6 +82,7 @@ func NewAPISocket(socketPath string) *API {
|
|
socketPath: socketPath,
|
|
connectionChan: make(chan ConnectionRequest, 1),
|
|
shutdownChan: make(chan struct{}, 1),
|
|
+ switchOrgChan: make(chan SwitchOrgRequest, 1),
|
|
peerStatuses: make(map[int]*PeerStatus),
|
|
}
|
|
|
|
@@ -85,6 +95,7 @@ func (s *API) Start() error {
|
|
mux.HandleFunc("/connect", s.handleConnect)
|
|
mux.HandleFunc("/status", s.handleStatus)
|
|
mux.HandleFunc("/exit", s.handleExit)
|
|
+ mux.HandleFunc("/switch-org", s.handleSwitchOrg)
|
|
|
|
s.server = &http.Server{
|
|
Handler: mux,
|
|
@@ -143,6 +154,11 @@ func (s *API) GetShutdownChannel() <-chan struct{} {
|
|
return s.shutdownChan
|
|
}
|
|
|
|
+// GetSwitchOrgChannel returns the channel for receiving org switch requests
|
|
+func (s *API) GetSwitchOrgChannel() <-chan SwitchOrgRequest {
|
|
+ return s.switchOrgChan
|
|
+}
|
|
+
|
|
// UpdatePeerStatus updates the status of a peer including endpoint and relay info
|
|
func (s *API) UpdatePeerStatus(siteID int, connected bool, rtt time.Duration, endpoint string, isRelay bool) {
|
|
s.statusMu.Lock()
|
|
@@ -198,6 +214,13 @@ func (s *API) SetVersion(version string) {
|
|
s.version = version
|
|
}
|
|
|
|
+// SetOrgID sets the organization ID reported in the /status response
|
|
+func (s *API) SetOrgID(orgID string) {
|
|
+ s.statusMu.Lock()
|
|
+ defer s.statusMu.Unlock()
|
|
+ s.orgID = orgID
|
|
+}
|
|
+
|
|
// UpdatePeerRelayStatus updates only the relay status of a peer
|
|
func (s *API) UpdatePeerRelayStatus(siteID int, endpoint string, isRelay bool) {
|
|
s.statusMu.Lock()
|
|
@@ -261,6 +284,7 @@ func (s *API) handleStatus(w http.ResponseWriter, r *http.Request) {
|
|
Registered: s.isRegistered,
|
|
TunnelIP: s.tunnelIP,
|
|
Version: s.version,
|
|
+ OrgID: s.orgID,
|
|
PeerStatuses: s.peerStatuses,
|
|
}
|
|
|
|
@@ -292,3 +316,44 @@ func (s *API) handleExit(w http.ResponseWriter, r *http.Request) {
|
|
"status": "shutdown initiated",
|
|
})
|
|
}
|
|
+
|
|
+// handleSwitchOrg handles POST /switch-org: it validates orgId and queues the switch request
|
|
+func (s *API) handleSwitchOrg(w http.ResponseWriter, r *http.Request) {
|
|
+ if r.Method != http.MethodPost {
|
|
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
+ return
|
|
+ }
|
|
+
|
|
+ var req SwitchOrgRequest
|
|
+ decoder := json.NewDecoder(r.Body)
|
|
+ if err := decoder.Decode(&req); err != nil {
|
|
+ http.Error(w, fmt.Sprintf("Invalid request: %v", err), http.StatusBadRequest)
|
|
+ return
|
|
+ }
|
|
+
|
|
+ // Validate required fields
|
|
+ if req.OrgID == "" {
|
|
+ http.Error(w, "Missing required field: orgId must be provided", http.StatusBadRequest)
|
|
+ return
|
|
+ }
|
|
+
|
|
+ logger.Info("Received org switch request to orgId: %s", req.OrgID)
|
|
+
|
|
+ // Hand the request to the consumer of GetSwitchOrgChannel without blocking
|
|
+ select {
|
|
+ case s.switchOrgChan <- req:
|
|
+ // Request queued successfully
|
|
+ default:
|
|
+ // A previous request is still queued; reject instead of blocking
|
|
+ http.Error(w, "Org switch already in progress", http.StatusTooManyRequests)
|
|
+ return
|
|
+ }
|
|
+
|
|
+ // Return a success response
|
|
+ w.Header().Set("Content-Type", "application/json")
|
|
+ w.WriteHeader(http.StatusAccepted)
|
|
+ json.NewEncoder(w).Encode(map[string]string{
|
|
+ "status": "org switch initiated",
|
|
+ "orgId": req.OrgID,
|
|
+ })
|
|
+}
|
|
diff --git a/olm/olm.go b/olm/olm.go
|
|
index 78080c4..5e292d6 100644
|
|
--- a/olm/olm.go
|
|
+++ b/olm/olm.go
|
|
@@ -58,6 +58,58 @@ type Config struct {
|
|
OrgID string
|
|
}
|
|
|
|
+// tunnelState holds all the active tunnel resources that need cleanup
|
|
+type tunnelState struct {
|
|
+ dev *device.Device
|
|
+ tdev tun.Device
|
|
+ uapiListener net.Listener
|
|
+ peerMonitor *peermonitor.PeerMonitor
|
|
+ stopRegister func()
|
|
+ connected bool
|
|
+}
|
|
+
|
|
+// teardownTunnel cleans up all tunnel resources
|
|
+func teardownTunnel(state *tunnelState) {
|
|
+ if state == nil {
|
|
+ return
|
|
+ }
|
|
+
|
|
+ logger.Info("Tearing down tunnel...")
|
|
+
|
|
+ // Stop registration messages
|
|
+ if state.stopRegister != nil {
|
|
+ state.stopRegister()
|
|
+ state.stopRegister = nil
|
|
+ }
|
|
+
|
|
+ // Stop peer monitor
|
|
+ if state.peerMonitor != nil {
|
|
+ state.peerMonitor.Stop()
|
|
+ state.peerMonitor = nil
|
|
+ }
|
|
+
|
|
+ // Close UAPI listener
|
|
+ if state.uapiListener != nil {
|
|
+ state.uapiListener.Close()
|
|
+ state.uapiListener = nil
|
|
+ }
|
|
+
|
|
+ // Close WireGuard device
|
|
+ if state.dev != nil {
|
|
+ state.dev.Close()
|
|
+ state.dev = nil
|
|
+ }
|
|
+
|
|
+ // Close TUN device
|
|
+ if state.tdev != nil {
|
|
+ state.tdev.Close()
|
|
+ state.tdev = nil
|
|
+ }
|
|
+
|
|
+ state.connected = false
|
|
+ logger.Info("Tunnel teardown complete")
|
|
+}
|
|
+
|
|
func Run(ctx context.Context, config Config) {
|
|
// Create a cancellable context for internal shutdown control
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
@@ -75,14 +127,14 @@ func Run(ctx context.Context, config Config) {
|
|
pingTimeout = config.PingTimeoutDuration
|
|
doHolepunch = config.Holepunch
|
|
privateKey wgtypes.Key
|
|
- connected bool
|
|
- dev *device.Device
|
|
wgData WgData
|
|
holePunchData HolePunchData
|
|
- uapiListener net.Listener
|
|
- tdev tun.Device
|
|
+ orgID = config.OrgID
|
|
)
|
|
|
|
+ // Tunnel state that can be torn down and recreated
|
|
+ tunnel := &tunnelState{}
|
|
+
|
|
stopHolepunch = make(chan struct{})
|
|
stopPing = make(chan struct{})
|
|
|
|
@@ -110,6 +162,7 @@ func Run(ctx context.Context, config Config) {
|
|
}
|
|
|
|
apiServer.SetVersion(config.Version)
|
|
+ apiServer.SetOrgID(orgID)
|
|
if err := apiServer.Start(); err != nil {
|
|
logger.Fatal("Failed to start HTTP server: %v", err)
|
|
}
|
|
@@ -249,14 +302,14 @@ func Run(ctx context.Context, config Config) {
|
|
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
|
logger.Debug("Received message: %v", msg.Data)
|
|
|
|
- if connected {
|
|
+ if tunnel.connected {
|
|
logger.Info("Already connected. Ignoring new connection request.")
|
|
return
|
|
}
|
|
|
|
- if stopRegister != nil {
|
|
- stopRegister()
|
|
- stopRegister = nil
|
|
+ if tunnel.stopRegister != nil {
|
|
+ tunnel.stopRegister()
|
|
+ tunnel.stopRegister = nil
|
|
}
|
|
|
|
close(stopHolepunch)
|
|
@@ -266,9 +319,9 @@ func Run(ctx context.Context, config Config) {
|
|
time.Sleep(500 * time.Millisecond)
|
|
|
|
// if there is an existing tunnel then close it
|
|
- if dev != nil {
|
|
+ if tunnel.dev != nil {
|
|
logger.Info("Got new message. Closing existing tunnel!")
|
|
- dev.Close()
|
|
+ tunnel.dev.Close()
|
|
}
|
|
|
|
jsonData, err := json.Marshal(msg.Data)
|
|
@@ -282,7 +335,7 @@ func Run(ctx context.Context, config Config) {
|
|
return
|
|
}
|
|
|
|
- tdev, err = func() (tun.Device, error) {
|
|
+ tunnel.tdev, err = func() (tun.Device, error) {
|
|
if runtime.GOOS == "darwin" {
|
|
interfaceName, err := findUnusedUTUN()
|
|
if err != nil {
|
|
@@ -301,7 +354,7 @@ func Run(ctx context.Context, config Config) {
|
|
return
|
|
}
|
|
|
|
- if realInterfaceName, err2 := tdev.Name(); err2 == nil {
|
|
+ if realInterfaceName, err2 := tunnel.tdev.Name(); err2 == nil {
|
|
interfaceName = realInterfaceName
|
|
}
|
|
|
|
@@ -321,9 +374,9 @@ func Run(ctx context.Context, config Config) {
|
|
return
|
|
}
|
|
|
|
- dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
|
+ tunnel.dev = device.NewDevice(tunnel.tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
|
|
|
- uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
|
+ tunnel.uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
|
if err != nil {
|
|
logger.Error("Failed to listen on uapi socket: %v", err)
|
|
os.Exit(1)
|
|
@@ -331,16 +384,16 @@ func Run(ctx context.Context, config Config) {
|
|
|
|
go func() {
|
|
for {
|
|
- conn, err := uapiListener.Accept()
|
|
+ conn, err := tunnel.uapiListener.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
- go dev.IpcHandle(conn)
|
|
+ go tunnel.dev.IpcHandle(conn)
|
|
}
|
|
}()
|
|
logger.Info("UAPI listener started")
|
|
|
|
- if err = dev.Up(); err != nil {
|
|
+ if err = tunnel.dev.Up(); err != nil {
|
|
logger.Error("Failed to bring up WireGuard device: %v", err)
|
|
}
|
|
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
|
@@ -350,7 +403,7 @@ func Run(ctx context.Context, config Config) {
|
|
apiServer.SetTunnelIP(wgData.TunnelIP)
|
|
}
|
|
|
|
- peerMonitor = peermonitor.NewPeerMonitor(
|
|
+ tunnel.peerMonitor = peermonitor.NewPeerMonitor(
|
|
func(siteID int, connected bool, rtt time.Duration) {
|
|
if apiServer != nil {
|
|
// Find the site config to get endpoint information
|
|
@@ -375,7 +428,7 @@ func Run(ctx context.Context, config Config) {
|
|
},
|
|
fixKey(privateKey.String()),
|
|
olm,
|
|
- dev,
|
|
+ tunnel.dev,
|
|
doHolepunch,
|
|
)
|
|
|
|
@@ -388,7 +441,7 @@ func Run(ctx context.Context, config Config) {
|
|
// Format the endpoint before configuring the peer.
|
|
site.Endpoint = formatEndpoint(site.Endpoint)
|
|
|
|
- if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
|
|
+ if err := ConfigurePeer(tunnel.dev, *site, privateKey, endpoint); err != nil {
|
|
logger.Error("Failed to configure peer: %v", err)
|
|
return
|
|
}
|
|
@@ -404,13 +457,13 @@ func Run(ctx context.Context, config Config) {
|
|
logger.Info("Configured peer %s", site.PublicKey)
|
|
}
|
|
|
|
- peerMonitor.Start()
|
|
+ tunnel.peerMonitor.Start()
|
|
|
|
if apiServer != nil {
|
|
apiServer.SetRegistered(true)
|
|
}
|
|
|
|
- connected = true
|
|
+ tunnel.connected = true
|
|
|
|
logger.Info("WireGuard device created.")
|
|
})
|
|
@@ -441,7 +494,7 @@ func Run(ctx context.Context, config Config) {
|
|
}
|
|
|
|
// Update the peer in WireGuard
|
|
- if dev != nil {
|
|
+ if tunnel.dev != nil {
|
|
// Find the existing peer to get old data
|
|
var oldRemoteSubnets string
|
|
var oldPublicKey string
|
|
@@ -456,7 +509,7 @@ func Run(ctx context.Context, config Config) {
|
|
// If the public key has changed, remove the old peer first
|
|
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
|
|
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
|
|
- if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
|
|
+ if err := RemovePeer(tunnel.dev, updateData.SiteId, oldPublicKey); err != nil {
|
|
logger.Error("Failed to remove old peer: %v", err)
|
|
return
|
|
}
|
|
@@ -465,7 +518,7 @@ func Run(ctx context.Context, config Config) {
|
|
// Format the endpoint before updating the peer.
|
|
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
|
|
|
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
|
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
|
|
logger.Error("Failed to update peer: %v", err)
|
|
return
|
|
}
|
|
@@ -524,11 +577,11 @@ func Run(ctx context.Context, config Config) {
|
|
}
|
|
|
|
// Add the peer to WireGuard
|
|
- if dev != nil {
|
|
+ if tunnel.dev != nil {
|
|
// Format the endpoint before adding the new peer.
|
|
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
|
|
|
|
- if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
|
+ if err := ConfigurePeer(tunnel.dev, siteConfig, privateKey, endpoint); err != nil {
|
|
logger.Error("Failed to add peer: %v", err)
|
|
return
|
|
}
|
|
@@ -585,8 +638,8 @@ func Run(ctx context.Context, config Config) {
|
|
}
|
|
|
|
// Remove the peer from WireGuard
|
|
- if dev != nil {
|
|
- if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
|
|
+ if tunnel.dev != nil {
|
|
+ if err := RemovePeer(tunnel.dev, removeData.SiteId, peerToRemove.PublicKey); err != nil {
|
|
logger.Error("Failed to remove peer: %v", err)
|
|
// Send error response if needed
|
|
return
|
|
@@ -640,7 +693,7 @@ func Run(ctx context.Context, config Config) {
|
|
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
|
|
}
|
|
|
|
- peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
|
+ tunnel.peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
|
})
|
|
|
|
olm.RegisterHandler("olm/register/no-sites", func(msg websocket.WSMessage) {
|
|
@@ -673,7 +726,7 @@ func Run(ctx context.Context, config Config) {
|
|
apiServer.SetConnectionStatus(true)
|
|
}
|
|
|
|
- if connected {
|
|
+ if tunnel.connected {
|
|
logger.Debug("Already connected, skipping registration")
|
|
return nil
|
|
}
|
|
@@ -682,11 +735,11 @@ func Run(ctx context.Context, config Config) {
|
|
|
|
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch)
|
|
|
|
- stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
|
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
|
"publicKey": publicKey.String(),
|
|
"relay": !doHolepunch,
|
|
"olmVersion": config.Version,
|
|
- "orgId": config.OrgID,
|
|
+ "orgId": orgID,
|
|
}, 1*time.Second)
|
|
|
|
go keepSendingPing(olm)
|
|
@@ -705,6 +758,49 @@ func Run(ctx context.Context, config Config) {
|
|
}
|
|
defer olm.Close()
|
|
|
|
+ // Listen for org switch requests from the API (after olm is created)
|
|
+ if apiServer != nil {
|
|
+ go func() {
|
|
+ for req := range apiServer.GetSwitchOrgChannel() {
|
|
+ logger.Info("Org switch requested via API to orgId: %s", req.OrgID)
|
|
+
|
|
+ // Update the orgId
|
|
+ orgID = req.OrgID
|
|
+
|
|
+ // Teardown existing tunnel
|
|
+ teardownTunnel(tunnel)
|
|
+
|
|
+ // Reset tunnel state
|
|
+ tunnel = &tunnelState{}
|
|
+
|
|
+ // Stop holepunch
|
|
+ select {
|
|
+ case <-stopHolepunch:
|
|
+ // Channel already closed
|
|
+ default:
|
|
+ close(stopHolepunch)
|
|
+ }
|
|
+ stopHolepunch = make(chan struct{})
|
|
+
|
|
+ // Clear API server state
|
|
+ apiServer.SetRegistered(false)
|
|
+ apiServer.SetTunnelIP("")
|
|
+ apiServer.SetOrgID(orgID)
|
|
+
|
|
+ // Send new registration message with updated orgId
|
|
+ publicKey := privateKey.PublicKey()
|
|
+ logger.Info("Sending registration message with new orgId: %s", orgID)
|
|
+
|
|
+ tunnel.stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
|
+ "publicKey": publicKey.String(),
|
|
+ "relay": !doHolepunch,
|
|
+ "olmVersion": config.Version,
|
|
+ "orgId": orgID,
|
|
+ }, 1*time.Second)
|
|
+ }
|
|
+ }()
|
|
+ }
|
|
+
|
|
select {
|
|
case <-ctx.Done():
|
|
logger.Info("Context cancelled")
|
|
@@ -717,9 +813,9 @@ func Run(ctx context.Context, config Config) {
|
|
close(stopHolepunch)
|
|
}
|
|
|
|
- if stopRegister != nil {
|
|
- stopRegister()
|
|
- stopRegister = nil
|
|
+ if tunnel.stopRegister != nil {
|
|
+ tunnel.stopRegister()
|
|
+ tunnel.stopRegister = nil
|
|
}
|
|
|
|
select {
|
|
@@ -729,16 +825,8 @@ func Run(ctx context.Context, config Config) {
|
|
close(stopPing)
|
|
}
|
|
|
|
- if peerMonitor != nil {
|
|
- peerMonitor.Stop()
|
|
- }
|
|
-
|
|
- if uapiListener != nil {
|
|
- uapiListener.Close()
|
|
- }
|
|
- if dev != nil {
|
|
- dev.Close()
|
|
- }
|
|
+ // Use teardownTunnel to clean up all tunnel resources
|
|
+ teardownTunnel(tunnel)
|
|
|
|
if apiServer != nil {
|
|
apiServer.Stop()
|