mirror of
https://github.com/fosrl/newt.git
synced 2026-03-03 09:16:44 +00:00
Working on it
This commit is contained in:
24
main.go
24
main.go
@@ -104,10 +104,10 @@ func main() {
|
|||||||
dns = os.Getenv("DNS")
|
dns = os.Getenv("DNS")
|
||||||
logLevel = os.Getenv("LOG_LEVEL")
|
logLevel = os.Getenv("LOG_LEVEL")
|
||||||
updownScript = os.Getenv("UPDOWN_SCRIPT")
|
updownScript = os.Getenv("UPDOWN_SCRIPT")
|
||||||
// interfaceName = os.Getenv("INTERFACE")
|
interfaceName = os.Getenv("INTERFACE")
|
||||||
// generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
|
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
|
||||||
// rm = os.Getenv("RM") == "true"
|
rm = os.Getenv("RM") == "true"
|
||||||
// acceptClients = os.Getenv("ACCEPT_CLIENTS") == "true"
|
acceptClients = os.Getenv("ACCEPT_CLIENTS") == "true"
|
||||||
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
|
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
|
||||||
dockerSocket = os.Getenv("DOCKER_SOCKET")
|
dockerSocket = os.Getenv("DOCKER_SOCKET")
|
||||||
pingIntervalStr := os.Getenv("PING_INTERVAL")
|
pingIntervalStr := os.Getenv("PING_INTERVAL")
|
||||||
@@ -136,14 +136,14 @@ func main() {
|
|||||||
if updownScript == "" {
|
if updownScript == "" {
|
||||||
flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed")
|
flag.StringVar(&updownScript, "updown", "", "Path to updown script to be called when targets are added or removed")
|
||||||
}
|
}
|
||||||
// if interfaceName == "" {
|
if interfaceName == "" {
|
||||||
// flag.StringVar(&interfaceName, "interface", "wg1", "Name of the WireGuard interface")
|
flag.StringVar(&interfaceName, "interface", "wg1", "Name of the WireGuard interface")
|
||||||
// }
|
}
|
||||||
// if generateAndSaveKeyTo == "" {
|
if generateAndSaveKeyTo == "" {
|
||||||
// flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "/tmp/newtkey", "Path to save generated private key")
|
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "/tmp/newtkey", "Path to save generated private key")
|
||||||
// }
|
}
|
||||||
// flag.BoolVar(&rm, "rm", false, "Remove the WireGuard interface")
|
flag.BoolVar(&rm, "rm", false, "Remove the WireGuard interface")
|
||||||
// flag.BoolVar(&acceptClients, "accept-clients", false, "Accept clients on the WireGuard interface")
|
flag.BoolVar(&acceptClients, "accept-clients", false, "Accept clients on the WireGuard interface")
|
||||||
if tlsPrivateKey == "" {
|
if tlsPrivateKey == "" {
|
||||||
flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate used for mTLS")
|
flag.StringVar(&tlsPrivateKey, "tls-client-cert", "", "Path to client certificate used for mTLS")
|
||||||
}
|
}
|
||||||
|
|||||||
76
wg/wg.go
76
wg/wg.go
@@ -4,6 +4,7 @@ package wg
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
@@ -62,7 +63,7 @@ type WireGuardService struct {
|
|||||||
host string
|
host string
|
||||||
serverPubKey string
|
serverPubKey string
|
||||||
token string
|
token string
|
||||||
stopGetConfig chan struct{}
|
stopGetConfig func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add this type definition
|
// Add this type definition
|
||||||
@@ -181,14 +182,21 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
|
|||||||
host: host,
|
host: host,
|
||||||
lastReadings: make(map[string]PeerReading),
|
lastReadings: make(map[string]PeerReading),
|
||||||
stopHolepunch: make(chan struct{}),
|
stopHolepunch: make(chan struct{}),
|
||||||
stopGetConfig: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the existing wireguard port (keep this part)
|
// Get the existing wireguard port (keep this part)
|
||||||
device, err := service.wgClient.Device(service.interfaceName)
|
device, err := service.wgClient.Device(service.interfaceName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
service.Port = uint16(device.ListenPort)
|
service.Port = uint16(device.ListenPort)
|
||||||
logger.Info("WireGuard interface %s already exists with port %d\n", service.interfaceName, service.Port)
|
if service.Port != 0 {
|
||||||
|
logger.Info("WireGuard interface %s already exists with port %d\n", service.interfaceName, service.Port)
|
||||||
|
} else {
|
||||||
|
service.Port, err = FindAvailableUDPPort(49152, 65535)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error finding available port: %v\n", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
service.Port, err = FindAvailableUDPPort(49152, 65535)
|
service.Port, err = FindAvailableUDPPort(49152, 65535)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -214,11 +222,9 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) Close(rm bool) {
|
func (s *WireGuardService) Close(rm bool) {
|
||||||
select {
|
if s.stopGetConfig != nil {
|
||||||
case <-s.stopGetConfig:
|
s.stopGetConfig()
|
||||||
// Already closed, do nothing
|
s.stopGetConfig = nil
|
||||||
default:
|
|
||||||
close(s.stopGetConfig)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.wgClient.Close()
|
s.wgClient.Close()
|
||||||
@@ -244,16 +250,12 @@ func (s *WireGuardService) SetToken(token string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) LoadRemoteConfig() error {
|
func (s *WireGuardService) LoadRemoteConfig() error {
|
||||||
// Send the initial message
|
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
|
||||||
err := s.sendGetConfigMessage()
|
"publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()),
|
||||||
if err != nil {
|
"port": s.Port,
|
||||||
logger.Error("Failed to send initial get-config message: %v", err)
|
}, 2*time.Second)
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start goroutine to periodically send the message until config is received
|
|
||||||
go s.keepSendingGetConfig()
|
|
||||||
|
|
||||||
|
logger.Info("Requesting WireGuard configuration from remote server")
|
||||||
go s.periodicBandwidthCheck()
|
go s.periodicBandwidthCheck()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -276,7 +278,10 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
|||||||
}
|
}
|
||||||
s.config = config
|
s.config = config
|
||||||
|
|
||||||
close(s.stopGetConfig)
|
if s.stopGetConfig != nil {
|
||||||
|
s.stopGetConfig()
|
||||||
|
s.stopGetConfig = nil
|
||||||
|
}
|
||||||
|
|
||||||
// Ensure the WireGuard interface and peers are configured
|
// Ensure the WireGuard interface and peers are configured
|
||||||
if err := s.ensureWireguardInterface(config); err != nil {
|
if err := s.ensureWireguardInterface(config); err != nil {
|
||||||
@@ -328,7 +333,10 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
|||||||
// Check if the interface already exists
|
// Check if the interface already exists
|
||||||
_, err = s.wgClient.Device(s.interfaceName)
|
_, err = s.wgClient.Device(s.interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("interface %s does not exist", s.interfaceName)
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
return fmt.Errorf("interface %s does not exist", s.interfaceName)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to get device: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the private key
|
// Parse the private key
|
||||||
@@ -949,33 +957,3 @@ func (s *WireGuardService) removeInterface() error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) sendGetConfigMessage() error {
|
|
||||||
err := s.client.SendMessage("newt/wg/get-config", map[string]interface{}{
|
|
||||||
"publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()),
|
|
||||||
"port": s.Port,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to send get-config message: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
logger.Info("Requesting WireGuard configuration from remote server")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) keepSendingGetConfig() {
|
|
||||||
ticker := time.NewTicker(3 * time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-s.stopGetConfig:
|
|
||||||
logger.Info("Stopping get-config messages")
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
if err := s.sendGetConfigMessage(); err != nil {
|
|
||||||
logger.Error("Failed to send periodic get-config: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user