From 40dfab31a5ebdbc500900525f90875373ac04884 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 25 Jul 2025 10:50:02 -0700 Subject: [PATCH] Maybe basic func --- linux.go | 8 +- wgnetstack/wgnetstack.go | 879 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 884 insertions(+), 3 deletions(-) create mode 100644 wgnetstack/wgnetstack.go diff --git a/linux.go b/linux.go index 76e33c6..a769b5a 100644 --- a/linux.go +++ b/linux.go @@ -9,11 +9,13 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/proxy" "github.com/fosrl/newt/websocket" - "github.com/fosrl/newt/wg" + + // "github.com/fosrl/newt/wg" + "github.com/fosrl/newt/wgnetstack" "github.com/fosrl/newt/wgtester" ) -var wgService *wg.WireGuardService +var wgService *wgnetstack.WireGuardService var wgTesterServer *wgtester.Server func setupClients(client *websocket.Client) { @@ -27,7 +29,7 @@ func setupClients(client *websocket.Client) { host = strings.TrimSuffix(host, "/") // Create WireGuard service - wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client) + wgService, err = wgnetstack.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client, "8.8.8.8") if err != nil { logger.Fatal("Failed to create WireGuard service: %v", err) } diff --git a/wgnetstack/wgnetstack.go b/wgnetstack/wgnetstack.go new file mode 100644 index 0000000..72da713 --- /dev/null +++ b/wgnetstack/wgnetstack.go @@ -0,0 +1,879 @@ +package wgnetstack + +import ( + "crypto/rand" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + mathrand "math/rand/v2" + "net" + "net/netip" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/network" + "github.com/fosrl/newt/websocket" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/netstack" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type WgConfig struct { + IpAddress string `json:"ipAddress"` + Peers []Peer `json:"peers"` +} + +type Peer struct { + PublicKey string `json:"publicKey"` + AllowedIPs []string `json:"allowedIps"` + Endpoint string `json:"endpoint"` +} + +type PeerBandwidth struct { + PublicKey string `json:"publicKey"` + BytesIn float64 `json:"bytesIn"` + BytesOut float64 `json:"bytesOut"` +} + +type PeerReading struct { + BytesReceived int64 + BytesTransmitted int64 + LastChecked time.Time +} + +type WireGuardService struct { + interfaceName string + mtu int + client *websocket.Client + config WgConfig + key wgtypes.Key + keyFilePath string + newtId string + lastReadings map[string]PeerReading + mu sync.Mutex + Port uint16 + stopHolepunch chan struct{} + host string + serverPubKey string + holePunchEndpoint string + token string + stopGetConfig func() + // Netstack fields + tun tun.Device + tnet *netstack.Net + device *device.Device + dns []netip.Addr +} + +// Add this type definition +type fixedPortBind struct { + port uint16 + conn.Bind +} + +func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { + // Ignore the requested port and use our fixed port + return b.Bind.Open(b.port) +} + +func NewFixedPortBind(port uint16) conn.Bind { + return &fixedPortBind{ + port: port, + Bind: conn.NewDefaultBind(), + } +} + +// find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester +func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { + if maxPort < minPort { + return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) + } + + // We need to check port+1 as well, so adjust the max port to avoid going out of range + adjustedMaxPort := maxPort - 1 + if adjustedMaxPort < minPort { + return 0, fmt.Errorf("insufficient port range to find consecutive ports: min=%d, max=%d", minPort, maxPort) + } + + // Create a slice of all ports in the range (excluding the last one) + portRange := make([]uint16, adjustedMaxPort-minPort+1) + for i := range portRange { + portRange[i] = minPort + uint16(i) + } + + // Fisher-Yates shuffle to randomize the port order + for i := len(portRange) - 1; i > 0; i-- { + j := mathrand.IntN(i + 1) + portRange[i], portRange[j] = portRange[j], portRange[i] + } + + // Try each port in the randomized order + for _, port := range portRange { + // Check if port is available + addr1 := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: int(port), + } + conn1, err1 := net.ListenUDP("udp", addr1) + if err1 != nil { + continue // Port is in use or there was an error, try next port + } + + // Check if port+1 is also available + addr2 := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: int(port + 1), + } + conn2, err2 := net.ListenUDP("udp", addr2) + if err2 != nil { + // The next port is not available, so close the first connection and try again + conn1.Close() + continue + } + + // Both ports are available, close connections and return the first port + conn1.Close() + conn2.Close() + return port, nil + } + + return 0, fmt.Errorf("no available consecutive UDP ports found in range %d-%d", minPort, maxPort) +} + +func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string) (*WireGuardService, error) { + var key wgtypes.Key + + // Load or generate private key + if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { + // Generate a new private key + key, err = wgtypes.GeneratePrivateKey() + if err != nil { + return nil, fmt.Errorf("failed to generate private key: %v", err) + } + // Save the key to the file + err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644) + if err != nil { + return nil, fmt.Errorf("failed to save private key: %v", err) + } + } else { + keyData, err := os.ReadFile(generateAndSaveKeyTo) + if err != nil { + return nil, fmt.Errorf("failed to read private key: %v", err) + } + key, err = wgtypes.ParseKey(strings.TrimSpace(string(keyData))) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %v", err) + } + } + + // Find an available port + port, err := FindAvailableUDPPort(49152, 65535) + if err != nil { + return nil, fmt.Errorf("error finding available port: %v", err) + } + + // Parse DNS addresses + dnsAddrs := []netip.Addr{netip.MustParseAddr(dns)} + + service := &WireGuardService{ + interfaceName: interfaceName, + mtu: mtu, + client: wsClient, + key: key, + keyFilePath: generateAndSaveKeyTo, + newtId: newtId, + host: host, + lastReadings: make(map[string]PeerReading), + stopHolepunch: make(chan struct{}), + Port: port, + dns: dnsAddrs, + } + + // Register websocket handlers + wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig) + wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer) + wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer) + wsClient.RegisterHandler("newt/wg/peer/update", service.handleUpdatePeer) + + return service, nil +} + +func (s *WireGuardService) Close(rm bool) { + if s.stopGetConfig != nil { + s.stopGetConfig() + s.stopGetConfig = nil + } + + // Close WireGuard device first - this will automatically close the TUN device + if s.device != nil { + s.device.Close() + s.device = nil + } + + // Clear references but don't manually close since device.Close() already did it + if s.tnet != nil { + s.tnet = nil + } + if s.tun != nil { + s.tun = nil // Don't call tun.Close() here since device.Close() already closed it + } +} + +func (s *WireGuardService) StartHolepunch(serverPubKey string, endpoint string) { + s.serverPubKey = serverPubKey + s.holePunchEndpoint = endpoint + + logger.Debug("Starting UDP hole punch to %s", s.holePunchEndpoint) + + // start the UDP holepunch + go s.keepSendingUDPHolePunch(s.holePunchEndpoint) +} + +func (s *WireGuardService) SetToken(token string) { + s.token = token +} + +// GetNetstackNet returns the netstack network interface for use by other components +func (s *WireGuardService) GetNetstackNet() *netstack.Net { + return s.tnet +} + +// IsReady returns true if the WireGuard service is ready to use +func (s *WireGuardService) IsReady() bool { + return s.device != nil && s.tnet != nil +} + +// GetPublicKey returns the public key of this WireGuard service +func (s *WireGuardService) GetPublicKey() wgtypes.Key { + return s.key.PublicKey() +} + +func (s *WireGuardService) LoadRemoteConfig() error { + s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{ + "publicKey": s.key.PublicKey().String(), + "port": s.Port, + }, 2*time.Second) + + logger.Info("Requesting WireGuard configuration from remote server") + go s.periodicBandwidthCheck() + + return nil +} + +func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { + var config WgConfig + + logger.Debug("Received message: %v", msg) + logger.Info("Received WireGuard clients configuration from remote server") + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &config); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return + } + s.config = config + + if s.stopGetConfig != nil { + s.stopGetConfig() + s.stopGetConfig = nil + } + + // Ensure the WireGuard interface and peers are configured + if err := s.ensureWireguardInterface(config); err != nil { + logger.Error("Failed to ensure WireGuard interface: %v", err) + } + + if err := s.ensureWireguardPeers(config.Peers); err != nil { + logger.Error("Failed to ensure WireGuard peers: %v", err) + } +} + +func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { + + // split off the cidr from the IP address + parts := strings.Split(wgconfig.IpAddress, "/") + if len(parts) != 2 { + return fmt.Errorf("invalid IP address format: %s", wgconfig.IpAddress) + } + // Parse the IP address and CIDR mask + tunnelIP := netip.MustParseAddr(parts[0]) + + // Parse the IP address from the config + // tunnelIP := netip.MustParseAddr(wgconfig.IpAddress) + + // Create TUN device and network stack using netstack + var err error + s.tun, s.tnet, err = netstack.CreateNetTUN( + []netip.Addr{tunnelIP}, + s.dns, + s.mtu) + if err != nil { + return fmt.Errorf("failed to create TUN device: %v", err) + } + + // Create WireGuard device + s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger( + device.LogLevelSilent, // Use silent logging by default - could be made configurable + "wireguard: ", + )) + + logger.Info("Private key is %s", fixKey(s.key.String())) + + // Configure WireGuard with private key + config := fmt.Sprintf("private_key=%s", fixKey(s.key.String())) + + err = s.device.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to configure WireGuard device: %v", err) + } + + // Bring up the device + err = s.device.Up() + if err != nil { + return fmt.Errorf("failed to bring up WireGuard device: %v", err) + } + + logger.Info("WireGuard netstack device created and configured") + return nil +} + +func fixKey(key string) string { + // Remove any whitespace + key = strings.TrimSpace(key) + + // Decode from base64 + decoded, err := base64.StdEncoding.DecodeString(key) + if err != nil { + logger.Fatal("Error decoding base64: %v", err) + } + + // Convert to hex + return hex.EncodeToString(decoded) +} + +func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { + // For netstack, we need to manage peers differently + // We'll configure peers directly on the device using IPC + + // First, clear all existing peers by getting current config and removing them + currentConfig, err := s.device.IpcGet() + if err != nil { + return fmt.Errorf("failed to get current device config: %v", err) + } + + // Parse current peers and remove them + lines := strings.Split(currentConfig, "\n") + var currentPeerKeys []string + for _, line := range lines { + if strings.HasPrefix(line, "public_key=") { + pubKey := strings.TrimPrefix(line, "public_key=") + currentPeerKeys = append(currentPeerKeys, pubKey) + } + } + + // Remove existing peers + for _, pubKey := range currentPeerKeys { + removeConfig := fmt.Sprintf("public_key=%s\nremove=true", pubKey) + if err := s.device.IpcSet(removeConfig); err != nil { + logger.Warn("Failed to remove peer %s: %v", pubKey, err) + } + } + + // Add new peers + for _, peer := range peers { + if err := s.addPeerToDevice(peer); err != nil { + return fmt.Errorf("failed to add peer: %v", err) + } + } + + return nil +} + +func (s *WireGuardService) addPeerToDevice(peer Peer) error { + // parse the key first + pubKey, err := wgtypes.ParseKey(peer.PublicKey) + if err != nil { + return fmt.Errorf("failed to parse public key: %v", err) + } + + // Build IPC configuration string for the peer + config := fmt.Sprintf("public_key=%s", fixKey(pubKey.String())) + + // Add allowed IPs + for _, allowedIP := range peer.AllowedIPs { + config += fmt.Sprintf("\nallowed_ip=%s", allowedIP) + } + + // Add endpoint if specified + if peer.Endpoint != "" { + config += fmt.Sprintf("\nendpoint=%s", peer.Endpoint) + } + + // Add persistent keepalive + config += "\npersistent_keepalive_interval=25" + + // Apply the configuration + if err := s.device.IpcSet(config); err != nil { + return fmt.Errorf("failed to configure peer: %v", err) + } + + logger.Info("Peer %s added successfully", peer.PublicKey) + return nil +} + +func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) { + logger.Debug("Received message: %v", msg.Data) + var peer Peer + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &peer); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return + } + + err = s.addPeerToDevice(peer) + if err != nil { + logger.Info("Error adding peer: %v", err) + return + } +} + +func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) { + logger.Debug("Received message: %v", msg.Data) + // parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" } + type RemoveRequest struct { + PublicKey string `json:"publicKey"` + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + var request RemoveRequest + if err := json.Unmarshal(jsonData, &request); err != nil { + logger.Info("Error unmarshaling data: %v", err) + return + } + + if err := s.removePeer(request.PublicKey); err != nil { + logger.Info("Error removing peer: %v", err) + return + } +} + +func (s *WireGuardService) removePeer(publicKey string) error { + + // Parse the public key + pubKey, err := wgtypes.ParseKey(publicKey) + if err != nil { + return fmt.Errorf("failed to parse public key: %v", err) + } + + // Build IPC configuration string to remove the peer + config := fmt.Sprintf("public_key=%s\nremove=true", fixKey(pubKey.String())) + + if err := s.device.IpcSet(config); err != nil { + return fmt.Errorf("failed to remove peer: %v", err) + } + + logger.Info("Peer %s removed successfully", publicKey) + return nil +} + +func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) { + logger.Debug("Received message: %v", msg.Data) + // Define a struct to match the incoming message structure with optional fields + type UpdatePeerRequest struct { + PublicKey string `json:"publicKey"` + AllowedIPs []string `json:"allowedIps,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + var request UpdatePeerRequest + if err := json.Unmarshal(jsonData, &request); err != nil { + logger.Info("Error unmarshaling peer data: %v", err) + return + } + + // Parse the public key + pubKey, err := wgtypes.ParseKey(request.PublicKey) + if err != nil { + logger.Info("Failed to parse public key: %v", err) + return + } + + // Build IPC configuration string to update the peer + config := fmt.Sprintf("public_key=%s\nupdate_only=true", fixKey(pubKey.String())) + + // Handle AllowedIPs update + if len(request.AllowedIPs) > 0 { + config += "\nreplace_allowed_ips=true" + for _, allowedIP := range request.AllowedIPs { + config += fmt.Sprintf("\nallowed_ip=%s", allowedIP) + } + logger.Info("Updating AllowedIPs for peer %s", request.PublicKey) + } + + // Handle Endpoint field special case + endpointSpecified := false + for key := range msg.Data.(map[string]interface{}) { + if key == "endpoint" { + endpointSpecified = true + break + } + } + + if endpointSpecified { + if request.Endpoint != "" { + config += fmt.Sprintf("\nendpoint=%s", request.Endpoint) + logger.Info("Updating Endpoint for peer %s to %s", request.PublicKey, request.Endpoint) + } else { + config += "\nendpoint=0.0.0.0:0" // Remove endpoint + logger.Info("Removing Endpoint for peer %s", request.PublicKey) + } + } + + // Always set persistent keepalive + config += "\npersistent_keepalive_interval=25" + + // Apply the configuration update + if err := s.device.IpcSet(config); err != nil { + logger.Info("Error updating peer configuration: %v", err) + return + } + + logger.Info("Peer %s updated successfully", request.PublicKey) +} + +func (s *WireGuardService) periodicBandwidthCheck() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for range ticker.C { + if err := s.reportPeerBandwidth(); err != nil { + logger.Info("Failed to report peer bandwidth: %v", err) + } + } +} + +func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { + if s.device == nil { + return []PeerBandwidth{}, nil + } + + // Get device statistics using IPC + stats, err := s.device.IpcGet() + if err != nil { + return nil, fmt.Errorf("failed to get device statistics: %v", err) + } + + peerBandwidths := []PeerBandwidth{} + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + // Parse the IPC response to extract peer statistics + lines := strings.Split(stats, "\n") + var currentPubKey string + var rxBytes, txBytes int64 + + for _, line := range lines { + if strings.HasPrefix(line, "public_key=") { + // Process previous peer if we have one + if currentPubKey != "" { + bandwidth := s.processPeerBandwidth(currentPubKey, rxBytes, txBytes, now) + if bandwidth != nil { + peerBandwidths = append(peerBandwidths, *bandwidth) + } + } + // Start new peer + currentPubKey = strings.TrimPrefix(line, "public_key=") + rxBytes = 0 + txBytes = 0 + } else if strings.HasPrefix(line, "rx_bytes=") { + rxBytes, _ = strconv.ParseInt(strings.TrimPrefix(line, "rx_bytes="), 10, 64) + } else if strings.HasPrefix(line, "tx_bytes=") { + txBytes, _ = strconv.ParseInt(strings.TrimPrefix(line, "tx_bytes="), 10, 64) + } + } + + // Process the last peer + if currentPubKey != "" { + bandwidth := s.processPeerBandwidth(currentPubKey, rxBytes, txBytes, now) + if bandwidth != nil { + peerBandwidths = append(peerBandwidths, *bandwidth) + } + } + + // Clean up old peers + devicePeers := make(map[string]bool) + lines = strings.Split(stats, "\n") + for _, line := range lines { + if strings.HasPrefix(line, "public_key=") { + pubKey := strings.TrimPrefix(line, "public_key=") + devicePeers[pubKey] = true + } + } + + for publicKey := range s.lastReadings { + if !devicePeers[publicKey] { + delete(s.lastReadings, publicKey) + } + } + + // parse the public keys and have them as base64 in the opposite order to fixKey + for i := range peerBandwidths { + pubKeyBytes, err := base64.StdEncoding.DecodeString(peerBandwidths[i].PublicKey) + if err != nil { + logger.Info("Failed to decode public key %s: %v", peerBandwidths[i].PublicKey, err) + continue + } + // Convert to hex + peerBandwidths[i].PublicKey = hex.EncodeToString(pubKeyBytes) + } + + return peerBandwidths, nil +} + +func (s *WireGuardService) processPeerBandwidth(publicKey string, rxBytes, txBytes int64, now time.Time) *PeerBandwidth { + currentReading := PeerReading{ + BytesReceived: rxBytes, + BytesTransmitted: txBytes, + LastChecked: now, + } + + var bytesInDiff, bytesOutDiff float64 + lastReading, exists := s.lastReadings[publicKey] + + if exists { + timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds() + if timeDiff > 0 { + // Calculate bytes transferred since last reading + bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived) + bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted) + + // Handle counter wraparound (if the counter resets or overflows) + if bytesInDiff < 0 { + bytesInDiff = float64(currentReading.BytesReceived) + } + if bytesOutDiff < 0 { + bytesOutDiff = float64(currentReading.BytesTransmitted) + } + + // Convert to MB + bytesInMB := bytesInDiff / (1024 * 1024) + bytesOutMB := bytesOutDiff / (1024 * 1024) + + // Update the last reading + s.lastReadings[publicKey] = currentReading + + return &PeerBandwidth{ + PublicKey: publicKey, + BytesIn: bytesInMB, + BytesOut: bytesOutMB, + } + } + } + + // For first reading or if readings are too close together, report 0 + s.lastReadings[publicKey] = currentReading + return &PeerBandwidth{ + PublicKey: publicKey, + BytesIn: 0, + BytesOut: 0, + } +} + +func (s *WireGuardService) reportPeerBandwidth() error { + bandwidths, err := s.calculatePeerBandwidth() + if err != nil { + return fmt.Errorf("failed to calculate peer bandwidth: %v", err) + } + + err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{ + "bandwidthData": bandwidths, + }) + if err != nil { + return fmt.Errorf("failed to send bandwidth data: %v", err) + } + + return nil +} + +func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { + + if s.serverPubKey == "" || s.token == "" { + logger.Debug("Server public key or token not set, skipping UDP hole punch") + return nil + } + + // Parse server address + serverSplit := strings.Split(serverAddr, ":") + if len(serverSplit) < 2 { + return fmt.Errorf("invalid server address format, expected hostname:port") + } + + serverHostname := serverSplit[0] + serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16) + if err != nil { + return fmt.Errorf("failed to parse server port: %v", err) + } + + // Resolve server hostname to IP + serverIPAddr := network.HostToAddr(serverHostname) + if serverIPAddr == nil { + return fmt.Errorf("failed to resolve server hostname") + } + + // Get client IP based on route to server + clientIP := network.GetClientIP(serverIPAddr.IP) + + // Create server and client configs + server := &network.Server{ + Hostname: serverHostname, + Addr: serverIPAddr, + Port: uint16(serverPort), + } + + client := &network.PeerNet{ + IP: clientIP, + Port: s.Port, + NewtID: s.newtId, + } + + // Setup raw connection with BPF filtering + rawConn := network.SetupRawConn(server, client) + defer rawConn.Close() + + // Create JSON payload + payload := struct { + NewtID string `json:"newtId"` + Token string `json:"token"` + }{ + NewtID: s.newtId, + Token: s.token, + } + + // Convert payload to JSON + payloadBytes, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %v", err) + } + + // Encrypt the payload using the server's WireGuard public key + encryptedPayload, err := s.encryptPayload(payloadBytes) + if err != nil { + return fmt.Errorf("failed to encrypt payload: %v", err) + } + + // Send the encrypted packet using the raw connection + err = network.SendDataPacket(encryptedPayload, rawConn, server, client) + if err != nil { + return fmt.Errorf("failed to send UDP packet: %v", err) + } + + return nil +} + +func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) { + // Generate an ephemeral keypair for this message + ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err) + } + ephemeralPublicKey := ephemeralPrivateKey.PublicKey() + + // Parse the server's public key + serverPubKey, err := wgtypes.ParseKey(s.serverPubKey) + if err != nil { + return nil, fmt.Errorf("failed to parse server public key: %v", err) + } + + // Use X25519 for key exchange (replacing deprecated ScalarMult) + var ephPrivKeyFixed [32]byte + copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:]) + + // Perform X25519 key exchange + sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:]) + if err != nil { + return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err) + } + + // Create an AEAD cipher using the shared secret + aead, err := chacha20poly1305.New(sharedSecret) + if err != nil { + return nil, fmt.Errorf("failed to create AEAD cipher: %v", err) + } + + // Generate a random nonce + nonce := make([]byte, aead.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, fmt.Errorf("failed to generate nonce: %v", err) + } + + // Encrypt the payload + ciphertext := aead.Seal(nil, nonce, payload, nil) + + // Prepare the final encrypted message + encryptedMsg := struct { + EphemeralPublicKey string `json:"ephemeralPublicKey"` + Nonce []byte `json:"nonce"` + Ciphertext []byte `json:"ciphertext"` + }{ + EphemeralPublicKey: ephemeralPublicKey.String(), + Nonce: nonce, + Ciphertext: ciphertext, + } + + return encryptedMsg, nil +} + +func (s *WireGuardService) keepSendingUDPHolePunch(host string) { + // send initial hole punch + if err := s.sendUDPHolePunch(host + ":21820"); err != nil { + logger.Error("Failed to send initial UDP hole punch: %v", err) + } + + ticker := time.NewTicker(3 * time.Second) + defer ticker.Stop() + + for { + select { + case <-s.stopHolepunch: + logger.Info("Stopping UDP holepunch") + return + case <-ticker.C: + if err := s.sendUDPHolePunch(host + ":21820"); err != nil { + logger.Error("Failed to send UDP hole punch: %v", err) + } + } + } +}