mirror of
https://github.com/fosrl/newt.git
synced 2026-03-09 12:16:39 +00:00
Add update message
This commit is contained in:
115
wg/wg.go
115
wg/wg.go
@@ -169,6 +169,7 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
|
|||||||
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
|
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
|
||||||
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
|
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
|
||||||
wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer)
|
wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer)
|
||||||
|
wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer)
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
@@ -515,6 +516,120 @@ func (s *WireGuardService) removePeer(publicKey string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) {
|
||||||
|
// Define a struct to match the incoming message structure with optional fields
|
||||||
|
type UpdatePeerRequest struct {
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
AllowedIPs []string `json:"allowedIps,omitempty"`
|
||||||
|
Endpoint string `json:"endpoint,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Error marshaling data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var request UpdatePeerRequest
|
||||||
|
if err := json.Unmarshal(jsonData, &request); err != nil {
|
||||||
|
logger.Info("Error unmarshaling peer data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// First, get the current peer configuration to preserve any unmodified fields
|
||||||
|
device, err := s.wgClient.Device(s.interfaceName)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Error getting WireGuard device: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKey, err := wgtypes.ParseKey(request.PublicKey)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Error parsing public key: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the existing peer configuration
|
||||||
|
var currentPeer *wgtypes.Peer
|
||||||
|
for _, p := range device.Peers {
|
||||||
|
if p.PublicKey == pubKey {
|
||||||
|
currentPeer = &p
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentPeer == nil {
|
||||||
|
logger.Info("Peer %s not found, cannot update", request.PublicKey)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the update peer config
|
||||||
|
peerConfig := wgtypes.PeerConfig{
|
||||||
|
PublicKey: pubKey,
|
||||||
|
UpdateOnly: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep the default persistent keepalive of 1 second
|
||||||
|
keepalive := time.Second
|
||||||
|
peerConfig.PersistentKeepaliveInterval = &keepalive
|
||||||
|
|
||||||
|
// Only update AllowedIPs if provided in the request
|
||||||
|
if request.AllowedIPs != nil && len(request.AllowedIPs) > 0 {
|
||||||
|
var allowedIPs []net.IPNet
|
||||||
|
for _, ipStr := range request.AllowedIPs {
|
||||||
|
_, ipNet, err := net.ParseCIDR(ipStr)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Error parsing allowed IP %s: %v", ipStr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
allowedIPs = append(allowedIPs, *ipNet)
|
||||||
|
}
|
||||||
|
peerConfig.AllowedIPs = allowedIPs
|
||||||
|
logger.Info("Updating AllowedIPs for peer %s", request.PublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle Endpoint field special case
|
||||||
|
// If Endpoint is included in the request but empty, we want to remove the endpoint
|
||||||
|
// If Endpoint is not included, we don't modify it
|
||||||
|
endpointSpecified := false
|
||||||
|
for key := range msg.Data.(map[string]interface{}) {
|
||||||
|
if key == "endpoint" {
|
||||||
|
endpointSpecified = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if endpointSpecified {
|
||||||
|
if request.Endpoint != "" {
|
||||||
|
// Update to new endpoint
|
||||||
|
endpoint, err := net.ResolveUDPAddr("udp", request.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Error resolving endpoint address %s: %v", request.Endpoint, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
peerConfig.Endpoint = endpoint
|
||||||
|
logger.Info("Updating Endpoint for peer %s to %s", request.PublicKey, request.Endpoint)
|
||||||
|
} else {
|
||||||
|
// Request contained endpoint field but it was empty/null - remove endpoint
|
||||||
|
// To remove an endpoint in WireGuard, we set it to nil and specify ReplaceAllowedIPs
|
||||||
|
peerConfig.Endpoint = nil
|
||||||
|
logger.Info("Removing Endpoint for peer %s", request.PublicKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply the configuration update
|
||||||
|
config := wgtypes.Config{
|
||||||
|
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
|
||||||
|
logger.Info("Error updating peer configuration: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Peer %s updated successfully", request.PublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) periodicBandwidthCheck() {
|
func (s *WireGuardService) periodicBandwidthCheck() {
|
||||||
ticker := time.NewTicker(10 * time.Second)
|
ticker := time.NewTicker(10 * time.Second)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|||||||
Reference in New Issue
Block a user