diff --git a/main.go b/main.go index fdf5baf..f8ea796 100644 --- a/main.go +++ b/main.go @@ -104,10 +104,10 @@ func main() { dns = os.Getenv("DNS") logLevel = os.Getenv("LOG_LEVEL") updownScript = os.Getenv("UPDOWN_SCRIPT") - // interfaceName = os.Getenv("INTERFACE") - // generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") - // rm = os.Getenv("RM") == "true" - // acceptClients = os.Getenv("ACCEPT_CLIENTS") == "true" + interfaceName = os.Getenv("INTERFACE") + generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") + rm = os.Getenv("RM") == "true" + acceptClients = os.Getenv("ACCEPT_CLIENTS") == "true" tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT") dockerSocket = os.Getenv("DOCKER_SOCKET") pingIntervalStr := os.Getenv("PING_INTERVAL") @@ -136,14 +136,14 @@ func main() { if updownScript == "" { flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed") } - // if interfaceName == "" { - // flag.StringVar(&interfaceName, "interface", "wg1", "Name of the WireGuard interface") - // } - // if generateAndSaveKeyTo == "" { - // flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "/tmp/newtkey", "Path to save generated private key") - // } - // flag.BoolVar(&rm, "rm", false, "Remove the WireGuard interface") - // flag.BoolVar(&acceptClients, "accept-clients", false, "Accept clients on the WireGuard interface") + if interfaceName == "" { + flag.StringVar(&interfaceName, "interface", "wg1", "Name of the WireGuard interface") + } + if generateAndSaveKeyTo == "" { + flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "/tmp/newtkey", "Path to save generated private key") + } + flag.BoolVar(&rm, "rm", false, "Remove the WireGuard interface") + flag.BoolVar(&acceptClients, "accept-clients", false, "Accept clients on the WireGuard interface") if tlsPrivateKey == "" { flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate used for mTLS") } diff --git a/wg/wg.go b/wg/wg.go index 1c378ea..a69c000 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -4,6 +4,7 @@ package wg import ( "encoding/json" + "errors" "fmt" "net" "os" @@ -62,7 +63,7 @@ type WireGuardService struct { host string serverPubKey string token string - stopGetConfig chan struct{} + stopGetConfig func() } // Add this type definition @@ -181,14 +182,21 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str host: host, lastReadings: make(map[string]PeerReading), stopHolepunch: make(chan struct{}), - stopGetConfig: make(chan struct{}), } // Get the existing wireguard port (keep this part) device, err := service.wgClient.Device(service.interfaceName) if err == nil { service.Port = uint16(device.ListenPort) - logger.Info("WireGuard interface %s already exists with port %d\n", service.interfaceName, service.Port) + if service.Port != 0 { + logger.Info("WireGuard interface %s already exists with port %d\n", service.interfaceName, service.Port) + } else { + service.Port, err = FindAvailableUDPPort(49152, 65535) + if err != nil { + fmt.Printf("Error finding available port: %v\n", err) + return nil, err + } + } } else { service.Port, err = FindAvailableUDPPort(49152, 65535) if err != nil { @@ -214,11 +222,9 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str } func (s *WireGuardService) Close(rm bool) { - select { - case <-s.stopGetConfig: - // Already closed, do nothing - default: - close(s.stopGetConfig) + if s.stopGetConfig != nil { + s.stopGetConfig() + s.stopGetConfig = nil } s.wgClient.Close() @@ -244,16 +250,12 @@ func (s *WireGuardService) SetToken(token string) { } func (s *WireGuardService) LoadRemoteConfig() error { - // Send the initial message - err := s.sendGetConfigMessage() - if err != nil { - logger.Error("Failed to send initial get-config message: %v", err) - return err - } - - // Start goroutine to periodically send the message until config is received - go s.keepSendingGetConfig() + s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{ + "publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()), + "port": s.Port, + }, 2*time.Second) + logger.Info("Requesting WireGuard configuration from remote server") go s.periodicBandwidthCheck() return nil @@ -276,7 +278,10 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { } s.config = config - close(s.stopGetConfig) + if s.stopGetConfig != nil { + s.stopGetConfig() + s.stopGetConfig = nil + } // Ensure the WireGuard interface and peers are configured if err := s.ensureWireguardInterface(config); err != nil { @@ -328,7 +333,10 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { // Check if the interface already exists _, err = s.wgClient.Device(s.interfaceName) if err != nil { - return fmt.Errorf("interface %s does not exist", s.interfaceName) + if errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("interface %s does not exist", s.interfaceName) + } + return fmt.Errorf("failed to get device: %v", err) } // Parse the private key @@ -949,33 +957,3 @@ func (s *WireGuardService) removeInterface() error { return nil } - -func (s *WireGuardService) sendGetConfigMessage() error { - err := s.client.SendMessage("newt/wg/get-config", map[string]interface{}{ - "publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()), - "port": s.Port, - }) - if err != nil { - logger.Error("Failed to send get-config message: %v", err) - return err - } - logger.Info("Requesting WireGuard configuration from remote server") - return nil -} - -func (s *WireGuardService) keepSendingGetConfig() { - ticker := time.NewTicker(3 * time.Second) - defer ticker.Stop() - - for { - select { - case <-s.stopGetConfig: - logger.Info("Stopping get-config messages") - return - case <-ticker.C: - if err := s.sendGetConfigMessage(); err != nil { - logger.Error("Failed to send periodic get-config: %v", err) - } - } - } -}