diff --git a/main.go b/main.go index 8eadf5a..faa80e6 100644 --- a/main.go +++ b/main.go @@ -579,7 +579,7 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( if wgService != nil { // add a udp proxy for localost and the wgService port // TODO: make sure this port is not used in a target - pm.AddTarget("udp", wgData.TunnelIP, int(wgService.Port), fmt.Sprintf("localhost:%d", wgService.Port)) + pm.AddTarget("udp", wgData.TunnelIP, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port)) } err = pm.Start() @@ -705,8 +705,21 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) <-sigCh - // Cleanup dev.Close() + + if wgService != nil { + wgService.Close() + } + + if pm != nil { + pm.Stop() + } + + if client != nil { + client.Close() + } + logger.Info("Exiting...") + os.Exit(0) } func parseTargetData(data interface{}) (TargetData, error) { diff --git a/wg/wg.go b/wg/wg.go index 4322756..b879c9c 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -176,6 +176,10 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str func (s *WireGuardService) Close() { s.wgClient.Close() + // Remove the WireGuard interface + if err := s.removeInterface(); err != nil { + logger.Error("Failed to remove WireGuard interface: %v", err) + } } func (s *WireGuardService) SetServerPubKey(serverPubKey string) { @@ -188,8 +192,16 @@ func (s *WireGuardService) SetToken(token string) { func (s *WireGuardService) LoadRemoteConfig() error { - err := s.client.SendMessage("newt/wg/get-config", map[string]interface{}{ + // get the exising wireguard port + device, err := s.wgClient.Device(s.interfaceName) + if err == nil { + s.Port = uint16(device.ListenPort) + logger.Info("WireGuard interface %s already exists with port %d\n", s.interfaceName, s.Port) + } + + err = s.client.SendMessage("newt/wg/get-config", map[string]interface{}{ "publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()), + "port": s.Port, }) if err != nil { logger.Error("Failed to send registration message: %v", err) @@ -885,3 +897,20 @@ func (s *WireGuardService) keepSendingUDPHolePunch(host string) { } } } + +func (s *WireGuardService) removeInterface() error { + // Remove the WireGuard interface + link, err := netlink.LinkByName(s.interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface: %v", err) + } + + err = netlink.LinkDel(link) + if err != nil { + return fmt.Errorf("failed to delete interface: %v", err) + } + + logger.Info("WireGuard interface %s removed successfully", s.interfaceName) + + return nil +}