mirror of
https://github.com/fosrl/newt.git
synced 2026-02-08 05:56:40 +00:00
Get wg working
This commit is contained in:
29
main.go
29
main.go
@@ -291,7 +291,7 @@ func main() {
|
||||
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||
}
|
||||
if interfaceName == "" {
|
||||
flag.StringVar(&interfaceName, "interface", "wg-1", "Name of the WireGuard interface")
|
||||
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
|
||||
}
|
||||
if generateAndSaveKeyTo == "" {
|
||||
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
||||
@@ -335,15 +335,7 @@ func main() {
|
||||
logger.Fatal("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
if reachableAt != "" {
|
||||
// Create WireGuard service
|
||||
wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard service: %v", err)
|
||||
}
|
||||
defer wgService.Close()
|
||||
}
|
||||
|
||||
var wgService *wg.WireGuardService
|
||||
// Create TUN device and network stack
|
||||
var tun tun.Device
|
||||
var tnet *netstack.Net
|
||||
@@ -352,6 +344,16 @@ func main() {
|
||||
var connected bool
|
||||
var wgData WgData
|
||||
|
||||
if reachableAt != "" {
|
||||
logger.Info("Sending reachableAt to server: %s", reachableAt)
|
||||
// Create WireGuard service
|
||||
wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard service: %v", err)
|
||||
}
|
||||
defer wgService.Close()
|
||||
}
|
||||
|
||||
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received terminate message")
|
||||
if pm != nil {
|
||||
@@ -419,7 +421,7 @@ func main() {
|
||||
public_key=%s
|
||||
allowed_ip=%s/32
|
||||
endpoint=%s
|
||||
persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
|
||||
persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
|
||||
|
||||
err = dev.IpcSet(config)
|
||||
if err != nil {
|
||||
@@ -439,6 +441,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
if err != nil {
|
||||
// Handle complete failure after all retries
|
||||
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
|
||||
fmt.Sprintf("%s", privateKey)
|
||||
}
|
||||
|
||||
if !connected {
|
||||
@@ -551,13 +554,15 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
logger.Debug("Public key: %s", publicKey)
|
||||
|
||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.PublicKey(),
|
||||
"publicKey": fmt.Sprintf("%s", publicKey),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
wgService.LoadRemoteConfig()
|
||||
|
||||
logger.Info("Sent registration message")
|
||||
return nil
|
||||
})
|
||||
|
||||
90
wg/wg.go
90
wg/wg.go
@@ -1,7 +1,6 @@
|
||||
package wg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -16,13 +15,6 @@ import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
var (
|
||||
interfaceName string
|
||||
mtuInt int
|
||||
lastReadings = make(map[string]PeerReading)
|
||||
mu sync.Mutex
|
||||
)
|
||||
|
||||
type WgConfig struct {
|
||||
PrivateKey string `json:"privateKey"`
|
||||
ListenPort int `json:"listenPort"`
|
||||
@@ -47,10 +39,6 @@ type PeerReading struct {
|
||||
LastChecked time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
wgClient *wgctrl.Client
|
||||
)
|
||||
|
||||
type WireGuardService struct {
|
||||
interfaceName string
|
||||
mtu int
|
||||
@@ -60,6 +48,7 @@ type WireGuardService struct {
|
||||
key wgtypes.Key
|
||||
reachableAt string
|
||||
lastReadings map[string]PeerReading
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) {
|
||||
@@ -107,27 +96,29 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene
|
||||
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
|
||||
wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer)
|
||||
|
||||
// Register connect handler to initiate configuration
|
||||
wsClient.OnConnect(service.loadRemoteConfig)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) Close() {
|
||||
s.client.Close()
|
||||
wgClient.Close()
|
||||
s.wgClient.Close()
|
||||
}
|
||||
|
||||
func (s *WireGuardService) loadRemoteConfig() error {
|
||||
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{ "publicKey": "%s", "endpoint": "%s" }`, s.key.PublicKey().String(), s.reachableAt)))
|
||||
func (s *WireGuardService) LoadRemoteConfig() error {
|
||||
|
||||
err := s.client.SendMessage("newt/wg/get-config", map[string]interface{}{
|
||||
"publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()),
|
||||
"endpoint": s.reachableAt,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("Requesting WireGuard configuration from remote server")
|
||||
|
||||
go s.periodicBandwidthCheck()
|
||||
|
||||
err := s.client.SendMessage("newt/wg/get-config", body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send config request: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -157,7 +148,7 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
||||
|
||||
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
// Check if the WireGuard interface exists
|
||||
_, err := netlink.LinkByName(interfaceName)
|
||||
_, err := netlink.LinkByName(s.interfaceName)
|
||||
if err != nil {
|
||||
if _, ok := err.(netlink.LinkNotFoundError); ok {
|
||||
// Interface doesn't exist, so create it
|
||||
@@ -165,12 +156,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard interface: %v", err)
|
||||
}
|
||||
logger.Info("Created WireGuard interface %s\n", interfaceName)
|
||||
logger.Info("Created WireGuard interface %s\n", s.interfaceName)
|
||||
} else {
|
||||
logger.Fatal("Error checking for WireGuard interface: %v", err)
|
||||
}
|
||||
} else {
|
||||
logger.Info("WireGuard interface %s already exists\n", interfaceName)
|
||||
logger.Info("WireGuard interface %s already exists\n", s.interfaceName)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -179,12 +170,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to assign IP address: %v", err)
|
||||
}
|
||||
logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, interfaceName)
|
||||
logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, s.interfaceName)
|
||||
|
||||
// Check if the interface already exists
|
||||
_, err = wgClient.Device(interfaceName)
|
||||
_, err = s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interface %s does not exist", interfaceName)
|
||||
return fmt.Errorf("interface %s does not exist", s.interfaceName)
|
||||
}
|
||||
|
||||
// Parse the private key
|
||||
@@ -201,18 +192,18 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
*config.ListenPort = wgconfig.ListenPort
|
||||
|
||||
// Create and configure the WireGuard interface
|
||||
err = wgClient.ConfigureDevice(interfaceName, config)
|
||||
err = s.wgClient.ConfigureDevice(s.interfaceName, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure WireGuard device: %v", err)
|
||||
}
|
||||
|
||||
// bring up the interface
|
||||
link, err := netlink.LinkByName(interfaceName)
|
||||
link, err := netlink.LinkByName(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface: %v", err)
|
||||
}
|
||||
|
||||
if err := netlink.LinkSetMTU(link, mtuInt); err != nil {
|
||||
if err := netlink.LinkSetMTU(link, s.mtu); err != nil {
|
||||
return fmt.Errorf("failed to set MTU: %v", err)
|
||||
}
|
||||
|
||||
@@ -224,21 +215,21 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
// logger.Warn("Failed to ensure MSS clamping: %v", err)
|
||||
// }
|
||||
|
||||
logger.Info("WireGuard interface %s created and configured", interfaceName)
|
||||
logger.Info("WireGuard interface %s created and configured", s.interfaceName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) createWireGuardInterface() error {
|
||||
wgLink := &netlink.GenericLink{
|
||||
LinkAttrs: netlink.LinkAttrs{Name: interfaceName},
|
||||
LinkAttrs: netlink.LinkAttrs{Name: s.interfaceName},
|
||||
LinkType: "wireguard",
|
||||
}
|
||||
return netlink.LinkAdd(wgLink)
|
||||
}
|
||||
|
||||
func (s *WireGuardService) assignIPAddress(ipAddress string) error {
|
||||
link, err := netlink.LinkByName(interfaceName)
|
||||
link, err := netlink.LinkByName(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface: %v", err)
|
||||
}
|
||||
@@ -253,7 +244,7 @@ func (s *WireGuardService) assignIPAddress(ipAddress string) error {
|
||||
|
||||
func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
||||
// get the current peers
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
device, err := s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
@@ -432,7 +423,7 @@ func (s *WireGuardService) addPeer(peer Peer) error {
|
||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||
}
|
||||
|
||||
if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
|
||||
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
|
||||
return fmt.Errorf("failed to add peer: %v", err)
|
||||
}
|
||||
|
||||
@@ -479,7 +470,7 @@ func (s *WireGuardService) removePeer(publicKey string) error {
|
||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||
}
|
||||
|
||||
if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
|
||||
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
|
||||
return fmt.Errorf("failed to remove peer: %v", err)
|
||||
}
|
||||
|
||||
@@ -500,7 +491,7 @@ func (s *WireGuardService) periodicBandwidthCheck() {
|
||||
}
|
||||
|
||||
func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
device, err := s.wgClient.Device(s.interfaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
@@ -508,8 +499,8 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
peerBandwidths := []PeerBandwidth{}
|
||||
now := time.Now()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for _, peer := range device.Peers {
|
||||
publicKey := peer.PublicKey.String()
|
||||
@@ -520,7 +511,7 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
}
|
||||
|
||||
var bytesInDiff, bytesOutDiff float64
|
||||
lastReading, exists := lastReadings[publicKey]
|
||||
lastReading, exists := s.lastReadings[publicKey]
|
||||
|
||||
if exists {
|
||||
timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds()
|
||||
@@ -564,11 +555,11 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
}
|
||||
|
||||
// Update the last reading
|
||||
lastReadings[publicKey] = currentReading
|
||||
s.lastReadings[publicKey] = currentReading
|
||||
}
|
||||
|
||||
// Clean up old peers
|
||||
for publicKey := range lastReadings {
|
||||
for publicKey := range s.lastReadings {
|
||||
found := false
|
||||
for _, peer := range device.Peers {
|
||||
if peer.PublicKey.String() == publicKey {
|
||||
@@ -577,7 +568,7 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
delete(lastReadings, publicKey)
|
||||
delete(s.lastReadings, publicKey)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -590,12 +581,9 @@ func (s *WireGuardService) reportPeerBandwidth() error {
|
||||
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bandwidths)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal bandwidth data: %v", err)
|
||||
}
|
||||
|
||||
err = s.client.SendMessage("wg/bandwidth", jsonData)
|
||||
err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{
|
||||
"bandwidthData": bandwidths,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send bandwidth data: %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user