Get wg working

This commit is contained in:
Owen
2025-02-21 16:12:12 -05:00
parent 56e75902e3
commit 95eab504fa
2 changed files with 56 additions and 63 deletions

29
main.go
View File

@@ -291,7 +291,7 @@ func main() {
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
} }
if interfaceName == "" { if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg-1", "Name of the WireGuard interface") flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
} }
if generateAndSaveKeyTo == "" { if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
@@ -335,15 +335,7 @@ func main() {
logger.Fatal("Failed to create client: %v", err) logger.Fatal("Failed to create client: %v", err)
} }
if reachableAt != "" { var wgService *wg.WireGuardService
// Create WireGuard service
wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
if err != nil {
logger.Fatal("Failed to create WireGuard service: %v", err)
}
defer wgService.Close()
}
// Create TUN device and network stack // Create TUN device and network stack
var tun tun.Device var tun tun.Device
var tnet *netstack.Net var tnet *netstack.Net
@@ -352,6 +344,16 @@ func main() {
var connected bool var connected bool
var wgData WgData var wgData WgData
if reachableAt != "" {
logger.Info("Sending reachableAt to server: %s", reachableAt)
// Create WireGuard service
wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
if err != nil {
logger.Fatal("Failed to create WireGuard service: %v", err)
}
defer wgService.Close()
}
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) { client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
logger.Info("Received terminate message") logger.Info("Received terminate message")
if pm != nil { if pm != nil {
@@ -419,7 +421,7 @@ func main() {
public_key=%s public_key=%s
allowed_ip=%s/32 allowed_ip=%s/32
endpoint=%s endpoint=%s
persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint) persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
err = dev.IpcSet(config) err = dev.IpcSet(config)
if err != nil { if err != nil {
@@ -439,6 +441,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
if err != nil { if err != nil {
// Handle complete failure after all retries // Handle complete failure after all retries
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err) logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
fmt.Sprintf("%s", privateKey)
} }
if !connected { if !connected {
@@ -551,13 +554,15 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
logger.Debug("Public key: %s", publicKey) logger.Debug("Public key: %s", publicKey)
err := client.SendMessage("newt/wg/register", map[string]interface{}{ err := client.SendMessage("newt/wg/register", map[string]interface{}{
"publicKey": publicKey.PublicKey(), "publicKey": fmt.Sprintf("%s", publicKey),
}) })
if err != nil { if err != nil {
logger.Error("Failed to send registration message: %v", err) logger.Error("Failed to send registration message: %v", err)
return err return err
} }
wgService.LoadRemoteConfig()
logger.Info("Sent registration message") logger.Info("Sent registration message")
return nil return nil
}) })

View File

@@ -1,7 +1,6 @@
package wg package wg
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
@@ -16,13 +15,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
var (
interfaceName string
mtuInt int
lastReadings = make(map[string]PeerReading)
mu sync.Mutex
)
type WgConfig struct { type WgConfig struct {
PrivateKey string `json:"privateKey"` PrivateKey string `json:"privateKey"`
ListenPort int `json:"listenPort"` ListenPort int `json:"listenPort"`
@@ -47,10 +39,6 @@ type PeerReading struct {
LastChecked time.Time LastChecked time.Time
} }
var (
wgClient *wgctrl.Client
)
type WireGuardService struct { type WireGuardService struct {
interfaceName string interfaceName string
mtu int mtu int
@@ -60,6 +48,7 @@ type WireGuardService struct {
key wgtypes.Key key wgtypes.Key
reachableAt string reachableAt string
lastReadings map[string]PeerReading lastReadings map[string]PeerReading
mu sync.Mutex
} }
func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) { func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) {
@@ -107,27 +96,29 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer) wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer) wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer)
// Register connect handler to initiate configuration
wsClient.OnConnect(service.loadRemoteConfig)
return service, nil return service, nil
} }
func (s *WireGuardService) Close() { func (s *WireGuardService) Close() {
s.client.Close() s.client.Close()
wgClient.Close() s.wgClient.Close()
} }
func (s *WireGuardService) loadRemoteConfig() error { func (s *WireGuardService) LoadRemoteConfig() error {
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{ "publicKey": "%s", "endpoint": "%s" }`, s.key.PublicKey().String(), s.reachableAt)))
err := s.client.SendMessage("newt/wg/get-config", map[string]interface{}{
"publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()),
"endpoint": s.reachableAt,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Requesting WireGuard configuration from remote server")
go s.periodicBandwidthCheck() go s.periodicBandwidthCheck()
err := s.client.SendMessage("newt/wg/get-config", body)
if err != nil {
return fmt.Errorf("failed to send config request: %v", err)
}
return nil return nil
} }
@@ -157,7 +148,7 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// Check if the WireGuard interface exists // Check if the WireGuard interface exists
_, err := netlink.LinkByName(interfaceName) _, err := netlink.LinkByName(s.interfaceName)
if err != nil { if err != nil {
if _, ok := err.(netlink.LinkNotFoundError); ok { if _, ok := err.(netlink.LinkNotFoundError); ok {
// Interface doesn't exist, so create it // Interface doesn't exist, so create it
@@ -165,12 +156,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
if err != nil { if err != nil {
logger.Fatal("Failed to create WireGuard interface: %v", err) logger.Fatal("Failed to create WireGuard interface: %v", err)
} }
logger.Info("Created WireGuard interface %s\n", interfaceName) logger.Info("Created WireGuard interface %s\n", s.interfaceName)
} else { } else {
logger.Fatal("Error checking for WireGuard interface: %v", err) logger.Fatal("Error checking for WireGuard interface: %v", err)
} }
} else { } else {
logger.Info("WireGuard interface %s already exists\n", interfaceName) logger.Info("WireGuard interface %s already exists\n", s.interfaceName)
return nil return nil
} }
@@ -179,12 +170,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
if err != nil { if err != nil {
logger.Fatal("Failed to assign IP address: %v", err) logger.Fatal("Failed to assign IP address: %v", err)
} }
logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, interfaceName) logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, s.interfaceName)
// Check if the interface already exists // Check if the interface already exists
_, err = wgClient.Device(interfaceName) _, err = s.wgClient.Device(s.interfaceName)
if err != nil { if err != nil {
return fmt.Errorf("interface %s does not exist", interfaceName) return fmt.Errorf("interface %s does not exist", s.interfaceName)
} }
// Parse the private key // Parse the private key
@@ -201,18 +192,18 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
*config.ListenPort = wgconfig.ListenPort *config.ListenPort = wgconfig.ListenPort
// Create and configure the WireGuard interface // Create and configure the WireGuard interface
err = wgClient.ConfigureDevice(interfaceName, config) err = s.wgClient.ConfigureDevice(s.interfaceName, config)
if err != nil { if err != nil {
return fmt.Errorf("failed to configure WireGuard device: %v", err) return fmt.Errorf("failed to configure WireGuard device: %v", err)
} }
// bring up the interface // bring up the interface
link, err := netlink.LinkByName(interfaceName) link, err := netlink.LinkByName(s.interfaceName)
if err != nil { if err != nil {
return fmt.Errorf("failed to get interface: %v", err) return fmt.Errorf("failed to get interface: %v", err)
} }
if err := netlink.LinkSetMTU(link, mtuInt); err != nil { if err := netlink.LinkSetMTU(link, s.mtu); err != nil {
return fmt.Errorf("failed to set MTU: %v", err) return fmt.Errorf("failed to set MTU: %v", err)
} }
@@ -224,21 +215,21 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// logger.Warn("Failed to ensure MSS clamping: %v", err) // logger.Warn("Failed to ensure MSS clamping: %v", err)
// } // }
logger.Info("WireGuard interface %s created and configured", interfaceName) logger.Info("WireGuard interface %s created and configured", s.interfaceName)
return nil return nil
} }
func (s *WireGuardService) createWireGuardInterface() error { func (s *WireGuardService) createWireGuardInterface() error {
wgLink := &netlink.GenericLink{ wgLink := &netlink.GenericLink{
LinkAttrs: netlink.LinkAttrs{Name: interfaceName}, LinkAttrs: netlink.LinkAttrs{Name: s.interfaceName},
LinkType: "wireguard", LinkType: "wireguard",
} }
return netlink.LinkAdd(wgLink) return netlink.LinkAdd(wgLink)
} }
func (s *WireGuardService) assignIPAddress(ipAddress string) error { func (s *WireGuardService) assignIPAddress(ipAddress string) error {
link, err := netlink.LinkByName(interfaceName) link, err := netlink.LinkByName(s.interfaceName)
if err != nil { if err != nil {
return fmt.Errorf("failed to get interface: %v", err) return fmt.Errorf("failed to get interface: %v", err)
} }
@@ -253,7 +244,7 @@ func (s *WireGuardService) assignIPAddress(ipAddress string) error {
func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
// get the current peers // get the current peers
device, err := wgClient.Device(interfaceName) device, err := s.wgClient.Device(s.interfaceName)
if err != nil { if err != nil {
return fmt.Errorf("failed to get device: %v", err) return fmt.Errorf("failed to get device: %v", err)
} }
@@ -432,7 +423,7 @@ func (s *WireGuardService) addPeer(peer Peer) error {
Peers: []wgtypes.PeerConfig{peerConfig}, Peers: []wgtypes.PeerConfig{peerConfig},
} }
if err := wgClient.ConfigureDevice(interfaceName, config); err != nil { if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
return fmt.Errorf("failed to add peer: %v", err) return fmt.Errorf("failed to add peer: %v", err)
} }
@@ -479,7 +470,7 @@ func (s *WireGuardService) removePeer(publicKey string) error {
Peers: []wgtypes.PeerConfig{peerConfig}, Peers: []wgtypes.PeerConfig{peerConfig},
} }
if err := wgClient.ConfigureDevice(interfaceName, config); err != nil { if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
return fmt.Errorf("failed to remove peer: %v", err) return fmt.Errorf("failed to remove peer: %v", err)
} }
@@ -500,7 +491,7 @@ func (s *WireGuardService) periodicBandwidthCheck() {
} }
func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
device, err := wgClient.Device(interfaceName) device, err := s.wgClient.Device(s.interfaceName)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get device: %v", err) return nil, fmt.Errorf("failed to get device: %v", err)
} }
@@ -508,8 +499,8 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
peerBandwidths := []PeerBandwidth{} peerBandwidths := []PeerBandwidth{}
now := time.Now() now := time.Now()
mu.Lock() s.mu.Lock()
defer mu.Unlock() defer s.mu.Unlock()
for _, peer := range device.Peers { for _, peer := range device.Peers {
publicKey := peer.PublicKey.String() publicKey := peer.PublicKey.String()
@@ -520,7 +511,7 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
} }
var bytesInDiff, bytesOutDiff float64 var bytesInDiff, bytesOutDiff float64
lastReading, exists := lastReadings[publicKey] lastReading, exists := s.lastReadings[publicKey]
if exists { if exists {
timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds() timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds()
@@ -564,11 +555,11 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
} }
// Update the last reading // Update the last reading
lastReadings[publicKey] = currentReading s.lastReadings[publicKey] = currentReading
} }
// Clean up old peers // Clean up old peers
for publicKey := range lastReadings { for publicKey := range s.lastReadings {
found := false found := false
for _, peer := range device.Peers { for _, peer := range device.Peers {
if peer.PublicKey.String() == publicKey { if peer.PublicKey.String() == publicKey {
@@ -577,7 +568,7 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
} }
} }
if !found { if !found {
delete(lastReadings, publicKey) delete(s.lastReadings, publicKey)
} }
} }
@@ -590,12 +581,9 @@ func (s *WireGuardService) reportPeerBandwidth() error {
return fmt.Errorf("failed to calculate peer bandwidth: %v", err) return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
} }
jsonData, err := json.Marshal(bandwidths) err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{
if err != nil { "bandwidthData": bandwidths,
return fmt.Errorf("failed to marshal bandwidth data: %v", err) })
}
err = s.client.SendMessage("wg/bandwidth", jsonData)
if err != nil { if err != nil {
return fmt.Errorf("failed to send bandwidth data: %v", err) return fmt.Errorf("failed to send bandwidth data: %v", err)
} }