From ad8a94fdc8ecb52af501cedbfca1109db1447718 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 13 Apr 2025 21:28:25 -0400 Subject: [PATCH] newts not being on when olm is started --- common.go | 113 ++++++++++++++++++++ main.go | 214 +++++++++++++++++++++++++------------ peermonitor/peermonitor.go | 4 +- 3 files changed, 258 insertions(+), 73 deletions(-) diff --git a/common.go b/common.go index a4727d8..d06d860 100644 --- a/common.go +++ b/common.go @@ -76,6 +76,35 @@ type fixedPortBind struct { conn.Bind } +// PeerAction represents a request to add, update, or remove a peer +type PeerAction struct { + Action string `json:"action"` // "add", "update", or "remove" + SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information +} + +// UpdatePeerData represents the data needed to update a peer +type UpdatePeerData struct { + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` + ServerPort uint16 `json:"serverPort"` +} + +// AddPeerData represents the data needed to add a peer +type AddPeerData struct { + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` + ServerPort uint16 `json:"serverPort"` +} + +// RemovePeerData represents the data needed to remove a peer +type RemovePeerData struct { + SiteId int `json:"siteId"` +} + func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { // Ignore the requested port and use our fixed port return b.Bind.Open(b.port) @@ -421,3 +450,87 @@ func keepSendingPing(olm *websocket.Client) { } } } + +// ConfigurePeer sets up or updates a peer within the WireGuard device +func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error { + siteHost, err := resolveDomain(siteConfig.Endpoint) + if err != nil { + return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err) + } + + // Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP + allowedIp := strings.Split(siteConfig.ServerIP, "/") + if len(allowedIp) > 1 { + allowedIp[1] = "32" + } else { + allowedIp = append(allowedIp, "32") + } + allowedIpStr := strings.Join(allowedIp, "/") + + // Construct WireGuard config for this peer + var configBuilder strings.Builder + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(siteConfig.PublicKey))) + configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIpStr)) + configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) + configBuilder.WriteString("persistent_keepalive_interval=1\n") + + config := configBuilder.String() + logger.Debug("Configuring peer with config: %s", config) + + err = dev.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to configure WireGuard peer: %v", err) + } + + // Set up peer monitoring + if peerMonitor != nil { + monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] + monitorPeer := fmt.Sprintf("%s:%d", monitorAddress, siteConfig.ServerPort+1) // +1 for the monitor port + + primaryRelay, err := resolveDomain(endpoint) // Using global endpoint variable + if err != nil { + logger.Warn("Failed to resolve primary relay endpoint: %v", err) + } + + wgConfig := &peermonitor.WireGuardConfig{ + SiteID: siteConfig.SiteId, + PublicKey: fixKey(siteConfig.PublicKey), + ServerIP: strings.Split(siteConfig.ServerIP, "/")[0], + Endpoint: siteConfig.Endpoint, + PrimaryRelay: primaryRelay, + } + + err = peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, wgConfig) + if err != nil { + logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err) + } else { + logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer) + } + } + + return nil +} + +// RemovePeer removes a peer from the WireGuard device +func RemovePeer(dev *device.Device, siteId int, publicKey string) error { + // Construct WireGuard config to remove the peer + var configBuilder strings.Builder + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(publicKey))) + configBuilder.WriteString("remove=true\n") + + config := configBuilder.String() + logger.Debug("Removing peer with config: %s", config) + + err := dev.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to remove WireGuard peer: %v", err) + } + + // Stop monitoring this peer + if peerMonitor != nil { + peerMonitor.RemovePeer(siteId) + logger.Info("Stopped monitoring for site %d", siteId) + } + + return nil +} diff --git a/main.go b/main.go index 5d0d003..0935d0b 100644 --- a/main.go +++ b/main.go @@ -13,7 +13,6 @@ import ( "regexp" "runtime" "strconv" - "strings" "syscall" "time" @@ -216,22 +215,21 @@ func main() { os.Exit(1) } - olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { - logger.Info("Received terminate message") - olm.Close() - }) + olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) { + logger.Info("Received message: %v", msg.Data) - olm.RegisterHandler("olm/wg/update", func(msg websocket.WSMessage) { jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Info("Error marshaling data: %v", err) return } - if err := json.Unmarshal(jsonData, &wgData); err != nil { + if err := json.Unmarshal(jsonData, &holePunchData); err != nil { logger.Info("Error unmarshaling target data: %v", err) return } + + gerbilServerPubKey = holePunchData.ServerPubKey }) connectTimes := 0 @@ -357,11 +355,6 @@ func main() { logger.Info("UAPI listener started") - primaryRelay, err := resolveDomain(endpoint) - if err != nil { - logger.Warn("Failed to resolve endpoint: %v", err) - } - peerMonitor = peermonitor.NewPeerMonitor( func(siteID int, connected bool, rtt time.Duration) { if connected { @@ -375,62 +368,14 @@ func main() { dev, ) - // Configure WireGuard with all sites as peers - var configBuilder strings.Builder - - // Start with the private key - configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", fixKey(privateKey.String()))) - - // Add each site as a peer + // loop over the sites and call ConfigurePeer for each one for _, site := range wgData.Sites { - siteHost, err := resolveDomain(site.Endpoint) + err = ConfigurePeer(dev, site, privateKey, endpoint) if err != nil { - logger.Warn("Failed to resolve endpoint for site %d: %v", site.SiteId, err) - continue + logger.Error("Failed to configure peer: %v", err) + return } - - // split off the cidr of the server ip which is just a string and add /32 for the allowed ip - allowedIp := strings.Split(site.ServerIP, "/") - if len(allowedIp) > 1 { - allowedIp[1] = "32" - } else { - allowedIp = append(allowedIp, "32") - } - allowedIpStr := strings.Join(allowedIp, "/") - - // Include peer info - configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(site.PublicKey))) - configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIpStr)) - configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) - configBuilder.WriteString("persistent_keepalive_interval=1\n") - - // take the first part of the allowedIp and the port from the endpoint and put them together - monitorAddress := strings.Split(site.ServerIP, "/")[0] - - monitorPeer := fmt.Sprintf("%s:%d", monitorAddress, site.ServerPort+1) // +1 for the monitor port - - wgConfig := &peermonitor.WireGuardConfig{ - SiteID: site.SiteId, - PublicKey: fixKey(site.PublicKey), - ServerIP: strings.Split(site.ServerIP, "/")[0], - Endpoint: site.Endpoint, - PrimaryRelay: primaryRelay, // Use the main endpoint as relay - } - - err = peerMonitor.AddPeer(site.SiteId, monitorPeer, wgConfig) - if err != nil { - logger.Warn("Failed to setup monitoring for site %d: %v", site.SiteId, err) - } else { - logger.Info("Started monitoring for site %d at %s", site.SiteId, monitorPeer) - } - } - - config := configBuilder.String() - logger.Debug("WireGuard config: %s", config) - - err = dev.IpcSet(config) - if err != nil { - logger.Error("Failed to configure WireGuard device: %v", err) + logger.Info("Configured peer %s", site.PublicKey) } // Bring up the device @@ -452,21 +397,148 @@ func main() { logger.Info("WireGuard device created.") }) - olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) { - logger.Info("Received message: %v", msg.Data) + olm.RegisterHandler("olm/wg/peer/update", func(msg websocket.WSMessage) { + logger.Info("Received update-peer message") jsonData, err := json.Marshal(msg.Data) if err != nil { - logger.Info("Error marshaling data: %v", err) + logger.Error("Error marshaling data: %v", err) return } - if err := json.Unmarshal(jsonData, &holePunchData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) + var updateData UpdatePeerData + if err := json.Unmarshal(jsonData, &updateData); err != nil { + logger.Error("Error unmarshaling update data: %v", err) return } - gerbilServerPubKey = holePunchData.ServerPubKey + // Convert to SiteConfig + siteConfig := SiteConfig{ + SiteId: updateData.SiteId, + Endpoint: updateData.Endpoint, + PublicKey: updateData.PublicKey, + ServerIP: updateData.ServerIP, + ServerPort: updateData.ServerPort, + } + + // Update the peer in WireGuard + if dev != nil { + if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { + logger.Error("Failed to update peer: %v", err) + // Send error response if needed + return + } + + // Update successful + logger.Info("Successfully updated peer for site %d", updateData.SiteId) + // If this is part of a WgData structure, update it + for i, site := range wgData.Sites { + if site.SiteId == updateData.SiteId { + wgData.Sites[i] = siteConfig + break + } + } + } else { + logger.Error("WireGuard device not initialized") + } + }) + + // Handler for adding a new peer + olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) { + logger.Info("Received add-peer message") + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var addData AddPeerData + if err := json.Unmarshal(jsonData, &addData); err != nil { + logger.Error("Error unmarshaling add data: %v", err) + return + } + + // Convert to SiteConfig + siteConfig := SiteConfig{ + SiteId: addData.SiteId, + Endpoint: addData.Endpoint, + PublicKey: addData.PublicKey, + ServerIP: addData.ServerIP, + ServerPort: addData.ServerPort, + } + + // Add the peer to WireGuard + if dev != nil { + if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { + logger.Error("Failed to add peer: %v", err) + return + } + + // Add successful + logger.Info("Successfully added peer for site %d", addData.SiteId) + + // Update WgData with the new peer + wgData.Sites = append(wgData.Sites, siteConfig) + } else { + logger.Error("WireGuard device not initialized") + } + }) + + // Handler for removing a peer + olm.RegisterHandler("olm/wg/peer/remove", func(msg websocket.WSMessage) { + logger.Info("Received remove-peer message") + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling data: %v", err) + return + } + + var removeData RemovePeerData + if err := json.Unmarshal(jsonData, &removeData); err != nil { + logger.Error("Error unmarshaling remove data: %v", err) + return + } + + // Find the peer to remove + var peerToRemove *SiteConfig + var newSites []SiteConfig + + for _, site := range wgData.Sites { + if site.SiteId == removeData.SiteId { + peerToRemove = &site + } else { + newSites = append(newSites, site) + } + } + + if peerToRemove == nil { + logger.Error("Peer with site ID %d not found", removeData.SiteId) + return + } + + // Remove the peer from WireGuard + if dev != nil { + if err := RemovePeer(dev, removeData.SiteId, peerToRemove.PublicKey); err != nil { + logger.Error("Failed to remove peer: %v", err) + // Send error response if needed + return + } + + // Remove successful + logger.Info("Successfully removed peer for site %d", removeData.SiteId) + + // Update WgData to remove the peer + wgData.Sites = newSites + } else { + logger.Error("WireGuard device not initialized") + } + }) + + olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { + logger.Info("Received terminate message") + olm.Close() }) olm.OnConnect(func() error { diff --git a/peermonitor/peermonitor.go b/peermonitor/peermonitor.go index be17717..09dade6 100644 --- a/peermonitor/peermonitor.go +++ b/peermonitor/peermonitor.go @@ -46,8 +46,8 @@ func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *w configs: make(map[int]*WireGuardConfig), callback: callback, interval: 1 * time.Second, // Default check interval - timeout: 500 * time.Millisecond, - maxAttempts: 3, + timeout: 1000 * time.Millisecond, + maxAttempts: 5, privateKey: privateKey, wsClient: wsClient, device: device,