From 95eab504fac0c87077d75152c03c149c3fc23efd Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 21 Feb 2025 16:12:12 -0500 Subject: [PATCH] Get wg working --- main.go | 29 ++++++++++-------- wg/wg.go | 90 ++++++++++++++++++++++++-------------------------------- 2 files changed, 56 insertions(+), 63 deletions(-) diff --git a/main.go b/main.go index 0592a51..1f0b289 100644 --- a/main.go +++ b/main.go @@ -291,7 +291,7 @@ func main() { flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") } if interfaceName == "" { - flag.StringVar(&interfaceName, "interface", "wg-1", "Name of the WireGuard interface") + flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") } if generateAndSaveKeyTo == "" { flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") @@ -335,15 +335,7 @@ func main() { logger.Fatal("Failed to create client: %v", err) } - if reachableAt != "" { - // Create WireGuard service - wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client) - if err != nil { - logger.Fatal("Failed to create WireGuard service: %v", err) - } - defer wgService.Close() - } - + var wgService *wg.WireGuardService // Create TUN device and network stack var tun tun.Device var tnet *netstack.Net @@ -352,6 +344,16 @@ func main() { var connected bool var wgData WgData + if reachableAt != "" { + logger.Info("Sending reachableAt to server: %s", reachableAt) + // Create WireGuard service + wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client) + if err != nil { + logger.Fatal("Failed to create WireGuard service: %v", err) + } + defer wgService.Close() + } + client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) { logger.Info("Received terminate message") if pm != nil { @@ -419,7 +421,7 @@ func main() { public_key=%s allowed_ip=%s/32 endpoint=%s -persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint) +persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint) err = dev.IpcSet(config) if err != nil { @@ -439,6 +441,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub if err != nil { // Handle complete failure after all retries logger.Error("Failed to ping %s: %v", wgData.ServerIP, err) + fmt.Sprintf("%s", privateKey) } if !connected { @@ -551,13 +554,15 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub logger.Debug("Public key: %s", publicKey) err := client.SendMessage("newt/wg/register", map[string]interface{}{ - "publicKey": publicKey.PublicKey(), + "publicKey": fmt.Sprintf("%s", publicKey), }) if err != nil { logger.Error("Failed to send registration message: %v", err) return err } + wgService.LoadRemoteConfig() + logger.Info("Sent registration message") return nil }) diff --git a/wg/wg.go b/wg/wg.go index 7e19958..6df69a7 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -1,7 +1,6 @@ package wg import ( - "bytes" "encoding/json" "fmt" "net" @@ -16,13 +15,6 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -var ( - interfaceName string - mtuInt int - lastReadings = make(map[string]PeerReading) - mu sync.Mutex -) - type WgConfig struct { PrivateKey string `json:"privateKey"` ListenPort int `json:"listenPort"` @@ -47,10 +39,6 @@ type PeerReading struct { LastChecked time.Time } -var ( - wgClient *wgctrl.Client -) - type WireGuardService struct { interfaceName string mtu int @@ -60,6 +48,7 @@ type WireGuardService struct { key wgtypes.Key reachableAt string lastReadings map[string]PeerReading + mu sync.Mutex } func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) { @@ -107,27 +96,29 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer) wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer) - // Register connect handler to initiate configuration - wsClient.OnConnect(service.loadRemoteConfig) - return service, nil } func (s *WireGuardService) Close() { s.client.Close() - wgClient.Close() + s.wgClient.Close() } -func (s *WireGuardService) loadRemoteConfig() error { - body := bytes.NewBuffer([]byte(fmt.Sprintf(`{ "publicKey": "%s", "endpoint": "%s" }`, s.key.PublicKey().String(), s.reachableAt))) +func (s *WireGuardService) LoadRemoteConfig() error { + + err := s.client.SendMessage("newt/wg/get-config", map[string]interface{}{ + "publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()), + "endpoint": s.reachableAt, + }) + if err != nil { + logger.Error("Failed to send registration message: %v", err) + return err + } + + logger.Info("Requesting WireGuard configuration from remote server") go s.periodicBandwidthCheck() - err := s.client.SendMessage("newt/wg/get-config", body) - if err != nil { - return fmt.Errorf("failed to send config request: %v", err) - } - return nil } @@ -157,7 +148,7 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { // Check if the WireGuard interface exists - _, err := netlink.LinkByName(interfaceName) + _, err := netlink.LinkByName(s.interfaceName) if err != nil { if _, ok := err.(netlink.LinkNotFoundError); ok { // Interface doesn't exist, so create it @@ -165,12 +156,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { if err != nil { logger.Fatal("Failed to create WireGuard interface: %v", err) } - logger.Info("Created WireGuard interface %s\n", interfaceName) + logger.Info("Created WireGuard interface %s\n", s.interfaceName) } else { logger.Fatal("Error checking for WireGuard interface: %v", err) } } else { - logger.Info("WireGuard interface %s already exists\n", interfaceName) + logger.Info("WireGuard interface %s already exists\n", s.interfaceName) return nil } @@ -179,12 +170,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { if err != nil { logger.Fatal("Failed to assign IP address: %v", err) } - logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, interfaceName) + logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, s.interfaceName) // Check if the interface already exists - _, err = wgClient.Device(interfaceName) + _, err = s.wgClient.Device(s.interfaceName) if err != nil { - return fmt.Errorf("interface %s does not exist", interfaceName) + return fmt.Errorf("interface %s does not exist", s.interfaceName) } // Parse the private key @@ -201,18 +192,18 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { *config.ListenPort = wgconfig.ListenPort // Create and configure the WireGuard interface - err = wgClient.ConfigureDevice(interfaceName, config) + err = s.wgClient.ConfigureDevice(s.interfaceName, config) if err != nil { return fmt.Errorf("failed to configure WireGuard device: %v", err) } // bring up the interface - link, err := netlink.LinkByName(interfaceName) + link, err := netlink.LinkByName(s.interfaceName) if err != nil { return fmt.Errorf("failed to get interface: %v", err) } - if err := netlink.LinkSetMTU(link, mtuInt); err != nil { + if err := netlink.LinkSetMTU(link, s.mtu); err != nil { return fmt.Errorf("failed to set MTU: %v", err) } @@ -224,21 +215,21 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { // logger.Warn("Failed to ensure MSS clamping: %v", err) // } - logger.Info("WireGuard interface %s created and configured", interfaceName) + logger.Info("WireGuard interface %s created and configured", s.interfaceName) return nil } func (s *WireGuardService) createWireGuardInterface() error { wgLink := &netlink.GenericLink{ - LinkAttrs: netlink.LinkAttrs{Name: interfaceName}, + LinkAttrs: netlink.LinkAttrs{Name: s.interfaceName}, LinkType: "wireguard", } return netlink.LinkAdd(wgLink) } func (s *WireGuardService) assignIPAddress(ipAddress string) error { - link, err := netlink.LinkByName(interfaceName) + link, err := netlink.LinkByName(s.interfaceName) if err != nil { return fmt.Errorf("failed to get interface: %v", err) } @@ -253,7 +244,7 @@ func (s *WireGuardService) assignIPAddress(ipAddress string) error { func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { // get the current peers - device, err := wgClient.Device(interfaceName) + device, err := s.wgClient.Device(s.interfaceName) if err != nil { return fmt.Errorf("failed to get device: %v", err) } @@ -432,7 +423,7 @@ func (s *WireGuardService) addPeer(peer Peer) error { Peers: []wgtypes.PeerConfig{peerConfig}, } - if err := wgClient.ConfigureDevice(interfaceName, config); err != nil { + if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { return fmt.Errorf("failed to add peer: %v", err) } @@ -479,7 +470,7 @@ func (s *WireGuardService) removePeer(publicKey string) error { Peers: []wgtypes.PeerConfig{peerConfig}, } - if err := wgClient.ConfigureDevice(interfaceName, config); err != nil { + if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { return fmt.Errorf("failed to remove peer: %v", err) } @@ -500,7 +491,7 @@ func (s *WireGuardService) periodicBandwidthCheck() { } func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { - device, err := wgClient.Device(interfaceName) + device, err := s.wgClient.Device(s.interfaceName) if err != nil { return nil, fmt.Errorf("failed to get device: %v", err) } @@ -508,8 +499,8 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { peerBandwidths := []PeerBandwidth{} now := time.Now() - mu.Lock() - defer mu.Unlock() + s.mu.Lock() + defer s.mu.Unlock() for _, peer := range device.Peers { publicKey := peer.PublicKey.String() @@ -520,7 +511,7 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { } var bytesInDiff, bytesOutDiff float64 - lastReading, exists := lastReadings[publicKey] + lastReading, exists := s.lastReadings[publicKey] if exists { timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds() @@ -564,11 +555,11 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { } // Update the last reading - lastReadings[publicKey] = currentReading + s.lastReadings[publicKey] = currentReading } // Clean up old peers - for publicKey := range lastReadings { + for publicKey := range s.lastReadings { found := false for _, peer := range device.Peers { if peer.PublicKey.String() == publicKey { @@ -577,7 +568,7 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { } } if !found { - delete(lastReadings, publicKey) + delete(s.lastReadings, publicKey) } } @@ -590,12 +581,9 @@ func (s *WireGuardService) reportPeerBandwidth() error { return fmt.Errorf("failed to calculate peer bandwidth: %v", err) } - jsonData, err := json.Marshal(bandwidths) - if err != nil { - return fmt.Errorf("failed to marshal bandwidth data: %v", err) - } - - err = s.client.SendMessage("wg/bandwidth", jsonData) + err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{ + "bandwidthData": bandwidths, + }) if err != nil { return fmt.Errorf("failed to send bandwidth data: %v", err) }