diff --git a/common.go b/common.go index d0bea54..57826fc 100644 --- a/common.go +++ b/common.go @@ -110,6 +110,12 @@ type RemovePeerData struct { SiteId int `json:"siteId"` } +type RelayPeerData struct { + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` +} + func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { // Ignore the requested port and use our fixed port return b.Bind.Open(b.port) diff --git a/main.go b/main.go index 2f50f18..f0ace5a 100644 --- a/main.go +++ b/main.go @@ -220,7 +220,7 @@ func main() { } olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) { - logger.Info("Received message: %v", msg.Data) + logger.Debug("Received message: %v", msg.Data) jsonData, err := json.Marshal(msg.Data) if err != nil { @@ -239,7 +239,7 @@ func main() { connectTimes := 0 // Register handlers for different message types olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { - logger.Info("Received message: %v", msg.Data) + logger.Debug("Received message: %v", msg.Data) if connectTimes > 0 { logger.Info("Already connected. Ignoring new connection request.") @@ -405,7 +405,7 @@ func main() { }) olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) { - logger.Info("Received update-peer message: %v", msg.Data) + logger.Debug("Received update-peer message: %v", msg.Data) jsonData, err := json.Marshal(msg.Data) if err != nil { @@ -452,7 +452,7 @@ func main() { // Handler for adding a new peer olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) { - logger.Info("Received add-peer message: %v", msg.Data) + logger.Debug("Received add-peer message: %v", msg.Data) jsonData, err := json.Marshal(msg.Data) if err != nil { @@ -506,7 +506,7 @@ func main() { // Handler for removing a peer olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) { - logger.Info("Received remove-peer message: %v", msg.Data) + logger.Debug("Received remove-peer message: %v", msg.Data) jsonData, err := json.Marshal(msg.Data) if err != nil { @@ -567,6 +567,29 @@ func main() { } }) + olm.RegisterHandler("olm/wg/peer/relay", func(msg websocket.WSMessage) { + logger.Debug("Received relay-peer message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var removeData RelayPeerData + if err := json.Unmarshal(jsonData, &removeData); err != nil { + logger.Error("Error unmarshaling remove data: %v", err) + return + } + + primaryRelay, err := resolveDomain(removeData.Endpoint) + if err != nil { + logger.Warn("Failed to resolve primary relay endpoint: %v", err) + } + + peerMonitor.HandleFailover(removeData.SiteId, primaryRelay) + }) + olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { logger.Info("Received terminate message") olm.Close() diff --git a/peermonitor/peermonitor.go b/peermonitor/peermonitor.go index 4f523b6..9570aec 100644 --- a/peermonitor/peermonitor.go +++ b/peermonitor/peermonitor.go @@ -179,12 +179,15 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status wgtester. // If disconnected, handle failover if !status.Connected { - pm.handleFailover(siteID) + // Send relay message to the server + if pm.wsClient != nil { + pm.sendRelay(siteID) + } } } // handleFailover handles failover to the relay server when a peer is disconnected -func (pm *PeerMonitor) handleFailover(siteID int) { +func (pm *PeerMonitor) HandleFailover(siteID int, relayEndpoint string) { pm.mutex.Lock() config, exists := pm.configs[siteID] pm.mutex.Unlock() @@ -198,7 +201,7 @@ func (pm *PeerMonitor) handleFailover(siteID int) { public_key=%s allowed_ip=%s/32 endpoint=%s:21820 -persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.ServerIP, config.PrimaryRelay) +persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.ServerIP, relayEndpoint) err := pm.device.IpcSet(wgConfig) if err != nil { @@ -207,11 +210,6 @@ persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.Server } logger.Info("Adjusted peer %d to point to relay!\n", siteID) - - // Send relay message to the server - if pm.wsClient != nil { - pm.sendRelay(siteID) - } } // sendRelay sends a relay message to the server