Get wg working

This commit is contained in:
Owen
2025-02-21 16:12:12 -05:00
parent 56e75902e3
commit 95eab504fa
2 changed files with 56 additions and 63 deletions

29
main.go
View File

@@ -291,7 +291,7 @@ func main() {
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
}
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg-1", "Name of the WireGuard interface")
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
}
if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
@@ -335,15 +335,7 @@ func main() {
logger.Fatal("Failed to create client: %v", err)
}
if reachableAt != "" {
// Create WireGuard service
wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
if err != nil {
logger.Fatal("Failed to create WireGuard service: %v", err)
}
defer wgService.Close()
}
var wgService *wg.WireGuardService
// Create TUN device and network stack
var tun tun.Device
var tnet *netstack.Net
@@ -352,6 +344,16 @@ func main() {
var connected bool
var wgData WgData
if reachableAt != "" {
logger.Info("Sending reachableAt to server: %s", reachableAt)
// Create WireGuard service
wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
if err != nil {
logger.Fatal("Failed to create WireGuard service: %v", err)
}
defer wgService.Close()
}
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
logger.Info("Received terminate message")
if pm != nil {
@@ -419,7 +421,7 @@ func main() {
public_key=%s
allowed_ip=%s/32
endpoint=%s
persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
err = dev.IpcSet(config)
if err != nil {
@@ -439,6 +441,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
if err != nil {
// Handle complete failure after all retries
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
fmt.Sprintf("%s", privateKey)
}
if !connected {
@@ -551,13 +554,15 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
logger.Debug("Public key: %s", publicKey)
err := client.SendMessage("newt/wg/register", map[string]interface{}{
"publicKey": publicKey.PublicKey(),
"publicKey": fmt.Sprintf("%s", publicKey),
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
wgService.LoadRemoteConfig()
logger.Info("Sent registration message")
return nil
})

View File

@@ -1,7 +1,6 @@
package wg
import (
"bytes"
"encoding/json"
"fmt"
"net"
@@ -16,13 +15,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var (
interfaceName string
mtuInt int
lastReadings = make(map[string]PeerReading)
mu sync.Mutex
)
type WgConfig struct {
PrivateKey string `json:"privateKey"`
ListenPort int `json:"listenPort"`
@@ -47,10 +39,6 @@ type PeerReading struct {
LastChecked time.Time
}
var (
wgClient *wgctrl.Client
)
type WireGuardService struct {
interfaceName string
mtu int
@@ -60,6 +48,7 @@ type WireGuardService struct {
key wgtypes.Key
reachableAt string
lastReadings map[string]PeerReading
mu sync.Mutex
}
func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) {
@@ -107,27 +96,29 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer)
// Register connect handler to initiate configuration
wsClient.OnConnect(service.loadRemoteConfig)
return service, nil
}
func (s *WireGuardService) Close() {
s.client.Close()
wgClient.Close()
s.wgClient.Close()
}
func (s *WireGuardService) loadRemoteConfig() error {
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{ "publicKey": "%s", "endpoint": "%s" }`, s.key.PublicKey().String(), s.reachableAt)))
func (s *WireGuardService) LoadRemoteConfig() error {
err := s.client.SendMessage("newt/wg/get-config", map[string]interface{}{
"publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()),
"endpoint": s.reachableAt,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Requesting WireGuard configuration from remote server")
go s.periodicBandwidthCheck()
err := s.client.SendMessage("newt/wg/get-config", body)
if err != nil {
return fmt.Errorf("failed to send config request: %v", err)
}
return nil
}
@@ -157,7 +148,7 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// Check if the WireGuard interface exists
_, err := netlink.LinkByName(interfaceName)
_, err := netlink.LinkByName(s.interfaceName)
if err != nil {
if _, ok := err.(netlink.LinkNotFoundError); ok {
// Interface doesn't exist, so create it
@@ -165,12 +156,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
if err != nil {
logger.Fatal("Failed to create WireGuard interface: %v", err)
}
logger.Info("Created WireGuard interface %s\n", interfaceName)
logger.Info("Created WireGuard interface %s\n", s.interfaceName)
} else {
logger.Fatal("Error checking for WireGuard interface: %v", err)
}
} else {
logger.Info("WireGuard interface %s already exists\n", interfaceName)
logger.Info("WireGuard interface %s already exists\n", s.interfaceName)
return nil
}
@@ -179,12 +170,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
if err != nil {
logger.Fatal("Failed to assign IP address: %v", err)
}
logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, interfaceName)
logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, s.interfaceName)
// Check if the interface already exists
_, err = wgClient.Device(interfaceName)
_, err = s.wgClient.Device(s.interfaceName)
if err != nil {
return fmt.Errorf("interface %s does not exist", interfaceName)
return fmt.Errorf("interface %s does not exist", s.interfaceName)
}
// Parse the private key
@@ -201,18 +192,18 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
*config.ListenPort = wgconfig.ListenPort
// Create and configure the WireGuard interface
err = wgClient.ConfigureDevice(interfaceName, config)
err = s.wgClient.ConfigureDevice(s.interfaceName, config)
if err != nil {
return fmt.Errorf("failed to configure WireGuard device: %v", err)
}
// bring up the interface
link, err := netlink.LinkByName(interfaceName)
link, err := netlink.LinkByName(s.interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface: %v", err)
}
if err := netlink.LinkSetMTU(link, mtuInt); err != nil {
if err := netlink.LinkSetMTU(link, s.mtu); err != nil {
return fmt.Errorf("failed to set MTU: %v", err)
}
@@ -224,21 +215,21 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// logger.Warn("Failed to ensure MSS clamping: %v", err)
// }
logger.Info("WireGuard interface %s created and configured", interfaceName)
logger.Info("WireGuard interface %s created and configured", s.interfaceName)
return nil
}
func (s *WireGuardService) createWireGuardInterface() error {
wgLink := &netlink.GenericLink{
LinkAttrs: netlink.LinkAttrs{Name: interfaceName},
LinkAttrs: netlink.LinkAttrs{Name: s.interfaceName},
LinkType: "wireguard",
}
return netlink.LinkAdd(wgLink)
}
func (s *WireGuardService) assignIPAddress(ipAddress string) error {
link, err := netlink.LinkByName(interfaceName)
link, err := netlink.LinkByName(s.interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface: %v", err)
}
@@ -253,7 +244,7 @@ func (s *WireGuardService) assignIPAddress(ipAddress string) error {
func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
// get the current peers
device, err := wgClient.Device(interfaceName)
device, err := s.wgClient.Device(s.interfaceName)
if err != nil {
return fmt.Errorf("failed to get device: %v", err)
}
@@ -432,7 +423,7 @@ func (s *WireGuardService) addPeer(peer Peer) error {
Peers: []wgtypes.PeerConfig{peerConfig},
}
if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
return fmt.Errorf("failed to add peer: %v", err)
}
@@ -479,7 +470,7 @@ func (s *WireGuardService) removePeer(publicKey string) error {
Peers: []wgtypes.PeerConfig{peerConfig},
}
if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
return fmt.Errorf("failed to remove peer: %v", err)
}
@@ -500,7 +491,7 @@ func (s *WireGuardService) periodicBandwidthCheck() {
}
func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
device, err := wgClient.Device(interfaceName)
device, err := s.wgClient.Device(s.interfaceName)
if err != nil {
return nil, fmt.Errorf("failed to get device: %v", err)
}
@@ -508,8 +499,8 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
peerBandwidths := []PeerBandwidth{}
now := time.Now()
mu.Lock()
defer mu.Unlock()
s.mu.Lock()
defer s.mu.Unlock()
for _, peer := range device.Peers {
publicKey := peer.PublicKey.String()
@@ -520,7 +511,7 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
}
var bytesInDiff, bytesOutDiff float64
lastReading, exists := lastReadings[publicKey]
lastReading, exists := s.lastReadings[publicKey]
if exists {
timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds()
@@ -564,11 +555,11 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
}
// Update the last reading
lastReadings[publicKey] = currentReading
s.lastReadings[publicKey] = currentReading
}
// Clean up old peers
for publicKey := range lastReadings {
for publicKey := range s.lastReadings {
found := false
for _, peer := range device.Peers {
if peer.PublicKey.String() == publicKey {
@@ -577,7 +568,7 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
}
}
if !found {
delete(lastReadings, publicKey)
delete(s.lastReadings, publicKey)
}
}
@@ -590,12 +581,9 @@ func (s *WireGuardService) reportPeerBandwidth() error {
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
}
jsonData, err := json.Marshal(bandwidths)
if err != nil {
return fmt.Errorf("failed to marshal bandwidth data: %v", err)
}
err = s.client.SendMessage("wg/bandwidth", jsonData)
err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{
"bandwidthData": bandwidths,
})
if err != nil {
return fmt.Errorf("failed to send bandwidth data: %v", err)
}