mirror of
https://github.com/fosrl/newt.git
synced 2026-03-04 17:56:40 +00:00
Get wg working
This commit is contained in:
29
main.go
29
main.go
@@ -291,7 +291,7 @@ func main() {
|
|||||||
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||||
}
|
}
|
||||||
if interfaceName == "" {
|
if interfaceName == "" {
|
||||||
flag.StringVar(&interfaceName, "interface", "wg-1", "Name of the WireGuard interface")
|
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
|
||||||
}
|
}
|
||||||
if generateAndSaveKeyTo == "" {
|
if generateAndSaveKeyTo == "" {
|
||||||
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
||||||
@@ -335,15 +335,7 @@ func main() {
|
|||||||
logger.Fatal("Failed to create client: %v", err)
|
logger.Fatal("Failed to create client: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if reachableAt != "" {
|
var wgService *wg.WireGuardService
|
||||||
// Create WireGuard service
|
|
||||||
wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to create WireGuard service: %v", err)
|
|
||||||
}
|
|
||||||
defer wgService.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create TUN device and network stack
|
// Create TUN device and network stack
|
||||||
var tun tun.Device
|
var tun tun.Device
|
||||||
var tnet *netstack.Net
|
var tnet *netstack.Net
|
||||||
@@ -352,6 +344,16 @@ func main() {
|
|||||||
var connected bool
|
var connected bool
|
||||||
var wgData WgData
|
var wgData WgData
|
||||||
|
|
||||||
|
if reachableAt != "" {
|
||||||
|
logger.Info("Sending reachableAt to server: %s", reachableAt)
|
||||||
|
// Create WireGuard service
|
||||||
|
wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("Failed to create WireGuard service: %v", err)
|
||||||
|
}
|
||||||
|
defer wgService.Close()
|
||||||
|
}
|
||||||
|
|
||||||
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
|
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
|
||||||
logger.Info("Received terminate message")
|
logger.Info("Received terminate message")
|
||||||
if pm != nil {
|
if pm != nil {
|
||||||
@@ -419,7 +421,7 @@ func main() {
|
|||||||
public_key=%s
|
public_key=%s
|
||||||
allowed_ip=%s/32
|
allowed_ip=%s/32
|
||||||
endpoint=%s
|
endpoint=%s
|
||||||
persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
|
persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint)
|
||||||
|
|
||||||
err = dev.IpcSet(config)
|
err = dev.IpcSet(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -439,6 +441,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
// Handle complete failure after all retries
|
// Handle complete failure after all retries
|
||||||
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
|
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
|
||||||
|
fmt.Sprintf("%s", privateKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !connected {
|
if !connected {
|
||||||
@@ -551,13 +554,15 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
|||||||
logger.Debug("Public key: %s", publicKey)
|
logger.Debug("Public key: %s", publicKey)
|
||||||
|
|
||||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||||
"publicKey": publicKey.PublicKey(),
|
"publicKey": fmt.Sprintf("%s", publicKey),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to send registration message: %v", err)
|
logger.Error("Failed to send registration message: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wgService.LoadRemoteConfig()
|
||||||
|
|
||||||
logger.Info("Sent registration message")
|
logger.Info("Sent registration message")
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|||||||
90
wg/wg.go
90
wg/wg.go
@@ -1,7 +1,6 @@
|
|||||||
package wg
|
package wg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@@ -16,13 +15,6 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
interfaceName string
|
|
||||||
mtuInt int
|
|
||||||
lastReadings = make(map[string]PeerReading)
|
|
||||||
mu sync.Mutex
|
|
||||||
)
|
|
||||||
|
|
||||||
type WgConfig struct {
|
type WgConfig struct {
|
||||||
PrivateKey string `json:"privateKey"`
|
PrivateKey string `json:"privateKey"`
|
||||||
ListenPort int `json:"listenPort"`
|
ListenPort int `json:"listenPort"`
|
||||||
@@ -47,10 +39,6 @@ type PeerReading struct {
|
|||||||
LastChecked time.Time
|
LastChecked time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
wgClient *wgctrl.Client
|
|
||||||
)
|
|
||||||
|
|
||||||
type WireGuardService struct {
|
type WireGuardService struct {
|
||||||
interfaceName string
|
interfaceName string
|
||||||
mtu int
|
mtu int
|
||||||
@@ -60,6 +48,7 @@ type WireGuardService struct {
|
|||||||
key wgtypes.Key
|
key wgtypes.Key
|
||||||
reachableAt string
|
reachableAt string
|
||||||
lastReadings map[string]PeerReading
|
lastReadings map[string]PeerReading
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) {
|
func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) {
|
||||||
@@ -107,27 +96,29 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene
|
|||||||
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
|
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
|
||||||
wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer)
|
wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer)
|
||||||
|
|
||||||
// Register connect handler to initiate configuration
|
|
||||||
wsClient.OnConnect(service.loadRemoteConfig)
|
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) Close() {
|
func (s *WireGuardService) Close() {
|
||||||
s.client.Close()
|
s.client.Close()
|
||||||
wgClient.Close()
|
s.wgClient.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) loadRemoteConfig() error {
|
func (s *WireGuardService) LoadRemoteConfig() error {
|
||||||
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{ "publicKey": "%s", "endpoint": "%s" }`, s.key.PublicKey().String(), s.reachableAt)))
|
|
||||||
|
err := s.client.SendMessage("newt/wg/get-config", map[string]interface{}{
|
||||||
|
"publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()),
|
||||||
|
"endpoint": s.reachableAt,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to send registration message: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Requesting WireGuard configuration from remote server")
|
||||||
|
|
||||||
go s.periodicBandwidthCheck()
|
go s.periodicBandwidthCheck()
|
||||||
|
|
||||||
err := s.client.SendMessage("newt/wg/get-config", body)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to send config request: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,7 +148,7 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
|||||||
|
|
||||||
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||||
// Check if the WireGuard interface exists
|
// Check if the WireGuard interface exists
|
||||||
_, err := netlink.LinkByName(interfaceName)
|
_, err := netlink.LinkByName(s.interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(netlink.LinkNotFoundError); ok {
|
if _, ok := err.(netlink.LinkNotFoundError); ok {
|
||||||
// Interface doesn't exist, so create it
|
// Interface doesn't exist, so create it
|
||||||
@@ -165,12 +156,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to create WireGuard interface: %v", err)
|
logger.Fatal("Failed to create WireGuard interface: %v", err)
|
||||||
}
|
}
|
||||||
logger.Info("Created WireGuard interface %s\n", interfaceName)
|
logger.Info("Created WireGuard interface %s\n", s.interfaceName)
|
||||||
} else {
|
} else {
|
||||||
logger.Fatal("Error checking for WireGuard interface: %v", err)
|
logger.Fatal("Error checking for WireGuard interface: %v", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
logger.Info("WireGuard interface %s already exists\n", interfaceName)
|
logger.Info("WireGuard interface %s already exists\n", s.interfaceName)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -179,12 +170,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to assign IP address: %v", err)
|
logger.Fatal("Failed to assign IP address: %v", err)
|
||||||
}
|
}
|
||||||
logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, interfaceName)
|
logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, s.interfaceName)
|
||||||
|
|
||||||
// Check if the interface already exists
|
// Check if the interface already exists
|
||||||
_, err = wgClient.Device(interfaceName)
|
_, err = s.wgClient.Device(s.interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("interface %s does not exist", interfaceName)
|
return fmt.Errorf("interface %s does not exist", s.interfaceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the private key
|
// Parse the private key
|
||||||
@@ -201,18 +192,18 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
|||||||
*config.ListenPort = wgconfig.ListenPort
|
*config.ListenPort = wgconfig.ListenPort
|
||||||
|
|
||||||
// Create and configure the WireGuard interface
|
// Create and configure the WireGuard interface
|
||||||
err = wgClient.ConfigureDevice(interfaceName, config)
|
err = s.wgClient.ConfigureDevice(s.interfaceName, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to configure WireGuard device: %v", err)
|
return fmt.Errorf("failed to configure WireGuard device: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// bring up the interface
|
// bring up the interface
|
||||||
link, err := netlink.LinkByName(interfaceName)
|
link, err := netlink.LinkByName(s.interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get interface: %v", err)
|
return fmt.Errorf("failed to get interface: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := netlink.LinkSetMTU(link, mtuInt); err != nil {
|
if err := netlink.LinkSetMTU(link, s.mtu); err != nil {
|
||||||
return fmt.Errorf("failed to set MTU: %v", err)
|
return fmt.Errorf("failed to set MTU: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -224,21 +215,21 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
|||||||
// logger.Warn("Failed to ensure MSS clamping: %v", err)
|
// logger.Warn("Failed to ensure MSS clamping: %v", err)
|
||||||
// }
|
// }
|
||||||
|
|
||||||
logger.Info("WireGuard interface %s created and configured", interfaceName)
|
logger.Info("WireGuard interface %s created and configured", s.interfaceName)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) createWireGuardInterface() error {
|
func (s *WireGuardService) createWireGuardInterface() error {
|
||||||
wgLink := &netlink.GenericLink{
|
wgLink := &netlink.GenericLink{
|
||||||
LinkAttrs: netlink.LinkAttrs{Name: interfaceName},
|
LinkAttrs: netlink.LinkAttrs{Name: s.interfaceName},
|
||||||
LinkType: "wireguard",
|
LinkType: "wireguard",
|
||||||
}
|
}
|
||||||
return netlink.LinkAdd(wgLink)
|
return netlink.LinkAdd(wgLink)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) assignIPAddress(ipAddress string) error {
|
func (s *WireGuardService) assignIPAddress(ipAddress string) error {
|
||||||
link, err := netlink.LinkByName(interfaceName)
|
link, err := netlink.LinkByName(s.interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get interface: %v", err)
|
return fmt.Errorf("failed to get interface: %v", err)
|
||||||
}
|
}
|
||||||
@@ -253,7 +244,7 @@ func (s *WireGuardService) assignIPAddress(ipAddress string) error {
|
|||||||
|
|
||||||
func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
||||||
// get the current peers
|
// get the current peers
|
||||||
device, err := wgClient.Device(interfaceName)
|
device, err := s.wgClient.Device(s.interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get device: %v", err)
|
return fmt.Errorf("failed to get device: %v", err)
|
||||||
}
|
}
|
||||||
@@ -432,7 +423,7 @@ func (s *WireGuardService) addPeer(peer Peer) error {
|
|||||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
|
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
|
||||||
return fmt.Errorf("failed to add peer: %v", err)
|
return fmt.Errorf("failed to add peer: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -479,7 +470,7 @@ func (s *WireGuardService) removePeer(publicKey string) error {
|
|||||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
|
if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil {
|
||||||
return fmt.Errorf("failed to remove peer: %v", err)
|
return fmt.Errorf("failed to remove peer: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -500,7 +491,7 @@ func (s *WireGuardService) periodicBandwidthCheck() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||||
device, err := wgClient.Device(interfaceName)
|
device, err := s.wgClient.Device(s.interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get device: %v", err)
|
return nil, fmt.Errorf("failed to get device: %v", err)
|
||||||
}
|
}
|
||||||
@@ -508,8 +499,8 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
|||||||
peerBandwidths := []PeerBandwidth{}
|
peerBandwidths := []PeerBandwidth{}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
mu.Lock()
|
s.mu.Lock()
|
||||||
defer mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
for _, peer := range device.Peers {
|
for _, peer := range device.Peers {
|
||||||
publicKey := peer.PublicKey.String()
|
publicKey := peer.PublicKey.String()
|
||||||
@@ -520,7 +511,7 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var bytesInDiff, bytesOutDiff float64
|
var bytesInDiff, bytesOutDiff float64
|
||||||
lastReading, exists := lastReadings[publicKey]
|
lastReading, exists := s.lastReadings[publicKey]
|
||||||
|
|
||||||
if exists {
|
if exists {
|
||||||
timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds()
|
timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds()
|
||||||
@@ -564,11 +555,11 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update the last reading
|
// Update the last reading
|
||||||
lastReadings[publicKey] = currentReading
|
s.lastReadings[publicKey] = currentReading
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up old peers
|
// Clean up old peers
|
||||||
for publicKey := range lastReadings {
|
for publicKey := range s.lastReadings {
|
||||||
found := false
|
found := false
|
||||||
for _, peer := range device.Peers {
|
for _, peer := range device.Peers {
|
||||||
if peer.PublicKey.String() == publicKey {
|
if peer.PublicKey.String() == publicKey {
|
||||||
@@ -577,7 +568,7 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
delete(lastReadings, publicKey)
|
delete(s.lastReadings, publicKey)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -590,12 +581,9 @@ func (s *WireGuardService) reportPeerBandwidth() error {
|
|||||||
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
|
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jsonData, err := json.Marshal(bandwidths)
|
err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{
|
||||||
if err != nil {
|
"bandwidthData": bandwidths,
|
||||||
return fmt.Errorf("failed to marshal bandwidth data: %v", err)
|
})
|
||||||
}
|
|
||||||
|
|
||||||
err = s.client.SendMessage("wg/bandwidth", jsonData)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to send bandwidth data: %v", err)
|
return fmt.Errorf("failed to send bandwidth data: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user