From 09d6829f8b552997af5b3c9f2372a3f4ea160244 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 31 Mar 2025 15:46:01 -0400 Subject: [PATCH] Add update message --- wg/wg.go | 115 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) diff --git a/wg/wg.go b/wg/wg.go index 9e78624..4322756 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -169,6 +169,7 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig) wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer) wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer) + wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer) return service, nil } @@ -515,6 +516,120 @@ func (s *WireGuardService) removePeer(publicKey string) error { return nil } +func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) { + // Define a struct to match the incoming message structure with optional fields + type UpdatePeerRequest struct { + PublicKey string `json:"publicKey"` + AllowedIPs []string `json:"allowedIps,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + var request UpdatePeerRequest + if err := json.Unmarshal(jsonData, &request); err != nil { + logger.Info("Error unmarshaling peer data: %v", err) + return + } + + // First, get the current peer configuration to preserve any unmodified fields + device, err := s.wgClient.Device(s.interfaceName) + if err != nil { + logger.Info("Error getting WireGuard device: %v", err) + return + } + + pubKey, err := wgtypes.ParseKey(request.PublicKey) + if err != nil { + logger.Info("Error parsing public key: %v", err) + return + } + + // Find the existing peer configuration + var currentPeer *wgtypes.Peer + for _, p := range device.Peers { + if p.PublicKey == pubKey { + currentPeer = &p + break + } + } + + if currentPeer == nil { + logger.Info("Peer %s not found, cannot update", request.PublicKey) + return + } + + // Create the update peer config + peerConfig := wgtypes.PeerConfig{ + PublicKey: pubKey, + UpdateOnly: true, + } + + // Keep the default persistent keepalive of 1 second + keepalive := time.Second + peerConfig.PersistentKeepaliveInterval = &keepalive + + // Only update AllowedIPs if provided in the request + if request.AllowedIPs != nil && len(request.AllowedIPs) > 0 { + var allowedIPs []net.IPNet + for _, ipStr := range request.AllowedIPs { + _, ipNet, err := net.ParseCIDR(ipStr) + if err != nil { + logger.Info("Error parsing allowed IP %s: %v", ipStr, err) + return + } + allowedIPs = append(allowedIPs, *ipNet) + } + peerConfig.AllowedIPs = allowedIPs + logger.Info("Updating AllowedIPs for peer %s", request.PublicKey) + } + + // Handle Endpoint field special case + // If Endpoint is included in the request but empty, we want to remove the endpoint + // If Endpoint is not included, we don't modify it + endpointSpecified := false + for key := range msg.Data.(map[string]interface{}) { + if key == "endpoint" { + endpointSpecified = true + break + } + } + + if endpointSpecified { + if request.Endpoint != "" { + // Update to new endpoint + endpoint, err := net.ResolveUDPAddr("udp", request.Endpoint) + if err != nil { + logger.Info("Error resolving endpoint address %s: %v", request.Endpoint, err) + return + } + peerConfig.Endpoint = endpoint + logger.Info("Updating Endpoint for peer %s to %s", request.PublicKey, request.Endpoint) + } else { + // Request contained endpoint field but it was empty/null - remove endpoint + // To remove an endpoint in WireGuard, we set it to nil and specify ReplaceAllowedIPs + peerConfig.Endpoint = nil + logger.Info("Removing Endpoint for peer %s", request.PublicKey) + } + } + + // Apply the configuration update + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peerConfig}, + } + + if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { + logger.Info("Error updating peer configuration: %v", err) + return + } + + logger.Info("Peer %s updated successfully", request.PublicKey) +} + func (s *WireGuardService) periodicBandwidthCheck() { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop()