Move wg into more of a class

This commit is contained in:
Owen
2025-02-20 20:37:31 -05:00
parent e8bd55bed9
commit f69a7f647d
2 changed files with 184 additions and 160 deletions

31
main.go
View File

@@ -20,6 +20,7 @@ import (
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/newt/proxy" "github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/websocket" "github.com/fosrl/newt/websocket"
"github.com/fosrl/newt/wg"
"golang.org/x/net/icmp" "golang.org/x/net/icmp"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
@@ -255,6 +256,9 @@ func main() {
privateKey wgtypes.Key privateKey wgtypes.Key
err error err error
logLevel string logLevel string
interfaceName string
generateAndSaveKeyTo string
reachableAt string
) )
// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values // if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
@@ -264,6 +268,9 @@ func main() {
mtu = os.Getenv("MTU") mtu = os.Getenv("MTU")
dns = os.Getenv("DNS") dns = os.Getenv("DNS")
logLevel = os.Getenv("LOG_LEVEL") logLevel = os.Getenv("LOG_LEVEL")
interfaceName = os.Getenv("INTERFACE")
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
reachableAt = os.Getenv("REACHABLE_AT")
if endpoint == "" { if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@@ -283,6 +290,15 @@ func main() {
if logLevel == "" { if logLevel == "" {
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
} }
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg-1", "Name of the WireGuard interface")
}
if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
}
if reachableAt == "" {
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
}
// do a --version check // do a --version check
version := flag.Bool("version", false, "Print the version") version := flag.Bool("version", false, "Print the version")
@@ -319,6 +335,21 @@ func main() {
logger.Fatal("Failed to create client: %v", err) logger.Fatal("Failed to create client: %v", err)
} }
// Create WireGuard service
wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
if err != nil {
logger.Fatal("Failed to create WireGuard service: %v", err)
}
// defer wgService.Close()
// Start the WireGuard service
if err := wgService.Start(); err != nil {
logger.Fatal("Failed to start WireGuard service: %v", err)
}
// Start bandwidth reporting
wgService.StartBandwidthReporting()
// Create TUN device and network stack // Create TUN device and network stack
var tun tun.Device var tun tun.Device
var tnet *netstack.Net var tnet *netstack.Net

253
wg/wg.go
View File

@@ -3,14 +3,10 @@ package wg
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"flag"
"fmt" "fmt"
"io"
"net" "net"
"net/http"
"os" "os"
"os/exec" "os/exec"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -58,147 +54,149 @@ var (
wgClient *wgctrl.Client wgClient *wgctrl.Client
) )
func main() { type WireGuardService struct {
var ( interfaceName string
err error mtu int
wgconfig WgConfig client *websocket.Client
remoteConfigURL string wgClient *wgctrl.Client
generateAndSaveKeyTo string config WgConfig
key wgtypes.Key
reachableAt string reachableAt string
logLevel string generateAndSaveKeyTo string
mtu string lastReadings map[string]PeerReading
) mu sync.Mutex
}
interfaceName = os.Getenv("INTERFACE") func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) {
remoteConfigURL = os.Getenv("REMOTE_CONFIG") wgClient, err := wgctrl.New()
listenAddr = os.Getenv("LISTEN")
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
reachableAt = os.Getenv("REACHABLE_AT")
logLevel = os.Getenv("LOG_LEVEL")
mtu = os.Getenv("MTU")
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
}
if remoteConfigURL == "" {
flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL to fetch remote configuration")
}
if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
}
if reachableAt == "" {
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
}
if logLevel == "" {
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
}
if mtu == "" {
flag.StringVar(&mtu, "mtu", "1280", "MTU of the WireGuard interface")
}
flag.Parse()
mtuInt, err = strconv.Atoi(mtu)
if err != nil { if err != nil {
logger.Fatal("Failed to parse MTU: %v", err) return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
} }
var key wgtypes.Key key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate private key: %v", err)
}
service := &WireGuardService{
interfaceName: interfaceName,
mtu: mtu,
client: wsClient,
wgClient: wgClient,
key: key,
reachableAt: reachableAt,
generateAndSaveKeyTo: generateAndSaveKeyTo,
lastReadings: make(map[string]PeerReading),
}
// Register websocket handlers
wsClient.RegisterHandler("wg/peer/config", service.handleConfig)
wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer)
wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer)
// Register connect handler to initiate configuration
wsClient.OnConnect(service.handleConnect)
return service, nil
}
func (s *WireGuardService) handleConnect() error {
logger.Debug("Public key: %s", s.key.PublicKey())
err := s.client.SendMessage("wg/register", map[string]interface{}{
"publicKey": fmt.Sprintf("%s", s.key.PublicKey()),
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent registration message")
return nil
}
func (s *WireGuardService) Start() error {
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
if generateAndSaveKeyTo != "" { if _, err := os.Stat(s.generateAndSaveKeyTo); os.IsNotExist(err) {
if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
// generate a new private key // generate a new private key
key, err = wgtypes.GeneratePrivateKey() s.key, err = wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
logger.Fatal("Failed to generate private key: %v", err) logger.Fatal("Failed to generate private key: %v", err)
} }
// save the key to the file // save the key to the file
err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644) err = os.WriteFile(s.generateAndSaveKeyTo, []byte(s.key.String()), 0644)
if err != nil { if err != nil {
logger.Fatal("Failed to save private key: %v", err) logger.Fatal("Failed to save private key: %v", err)
} }
} else { } else {
keyData, err := os.ReadFile(generateAndSaveKeyTo) keyData, err := os.ReadFile(s.generateAndSaveKeyTo)
if err != nil { if err != nil {
logger.Fatal("Failed to read private key: %v", err) logger.Fatal("Failed to read private key: %v", err)
} }
key, err = wgtypes.ParseKey(string(keyData)) s.key, err = wgtypes.ParseKey(string(keyData))
if err != nil { if err != nil {
logger.Fatal("Failed to parse private key: %v", err) logger.Fatal("Failed to parse private key: %v", err)
} }
} }
} else {
// if no generateAndSaveKeyTo is provided, ensure that the private key is provided // Get initial configuration
if wgconfig.PrivateKey == "" { err := s.loadRemoteConfig()
// generate a new one
key, err = wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
logger.Fatal("Failed to generate private key: %v", err) return fmt.Errorf("failed to load initial configuration: %v", err)
}
}
} }
// loop until we get the config return nil
for wgconfig.PrivateKey == "" {
logger.Info("Fetching remote config from %s", remoteConfigURL)
wgconfig, err = loadRemoteConfig(remoteConfigURL, key, reachableAt)
if err != nil {
logger.Error("Failed to load configuration: %v", err)
time.Sleep(5 * time.Second)
continue
}
wgconfig.PrivateKey = key.String()
}
wgClient, err = wgctrl.New()
if err != nil {
logger.Fatal("Failed to create WireGuard client: %v", err)
}
defer wgClient.Close()
// Ensure the WireGuard interface exists and is configured
if err := ensureWireguardInterface(wgconfig); err != nil {
logger.Fatal("Failed to ensure WireGuard interface: %v", err)
}
// Ensure the WireGuard peers exist
ensureWireguardPeers(wgconfig.Peers)
// go periodicBandwidthCheck(reportBandwidthTo)
} }
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) { func (s *WireGuardService) StartBandwidthReporting() {
var body *bytes.Buffer go s.periodicBandwidthCheck()
if reachableAt == "" { }
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, key.PublicKey().String())))
} else {
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, key.PublicKey().String(), reachableAt)))
}
resp, err := http.Post(url, "application/json", body)
if err != nil {
// print the error
logger.Error("Error fetching remote config %s: %v", url, err)
return WgConfig{}, err
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body) func (s *WireGuardService) loadRemoteConfig() error {
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, s.key.PublicKey().String(), s.reachableAt)))
// send a ws message to the server to get the config
err := s.client.SendMessage("wg/config/get", body)
if err != nil { if err != nil {
return WgConfig{}, err return fmt.Errorf("failed to send config request: %v", err)
} }
return nil
}
func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
var config WgConfig var config WgConfig
err = json.Unmarshal(data, &config)
return config, err jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
}
if err := json.Unmarshal(jsonData, &config); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
}
s.config = config
// Ensure the WireGuard interface and peers are configured
if err := s.ensureWireguardInterface(config); err != nil {
logger.Error("Failed to ensure WireGuard interface: %v", err)
}
if err := s.ensureWireguardPeers(config.Peers); err != nil {
logger.Error("Failed to ensure WireGuard peers: %v", err)
}
} }
func ensureWireguardInterface(wgconfig WgConfig) error { func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// Check if the WireGuard interface exists // Check if the WireGuard interface exists
_, err := netlink.LinkByName(interfaceName) _, err := netlink.LinkByName(interfaceName)
if err != nil { if err != nil {
if _, ok := err.(netlink.LinkNotFoundError); ok { if _, ok := err.(netlink.LinkNotFoundError); ok {
// Interface doesn't exist, so create it // Interface doesn't exist, so create it
err = createWireGuardInterface() err = s.createWireGuardInterface()
if err != nil { if err != nil {
logger.Fatal("Failed to create WireGuard interface: %v", err) logger.Fatal("Failed to create WireGuard interface: %v", err)
} }
@@ -212,7 +210,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
} }
// Assign IP address to the interface // Assign IP address to the interface
err = assignIPAddress(wgconfig.IpAddress) err = s.assignIPAddress(wgconfig.IpAddress)
if err != nil { if err != nil {
logger.Fatal("Failed to assign IP address: %v", err) logger.Fatal("Failed to assign IP address: %v", err)
} }
@@ -257,7 +255,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
return fmt.Errorf("failed to bring up interface: %v", err) return fmt.Errorf("failed to bring up interface: %v", err)
} }
if err := ensureMSSClamping(); err != nil { if err := s.ensureMSSClamping(); err != nil {
logger.Warn("Failed to ensure MSS clamping: %v", err) logger.Warn("Failed to ensure MSS clamping: %v", err)
} }
@@ -266,7 +264,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
return nil return nil
} }
func createWireGuardInterface() error { func (s *WireGuardService) createWireGuardInterface() error {
wgLink := &netlink.GenericLink{ wgLink := &netlink.GenericLink{
LinkAttrs: netlink.LinkAttrs{Name: interfaceName}, LinkAttrs: netlink.LinkAttrs{Name: interfaceName},
LinkType: "wireguard", LinkType: "wireguard",
@@ -274,7 +272,7 @@ func createWireGuardInterface() error {
return netlink.LinkAdd(wgLink) return netlink.LinkAdd(wgLink)
} }
func assignIPAddress(ipAddress string) error { func (s *WireGuardService) assignIPAddress(ipAddress string) error {
link, err := netlink.LinkByName(interfaceName) link, err := netlink.LinkByName(interfaceName)
if err != nil { if err != nil {
return fmt.Errorf("failed to get interface: %v", err) return fmt.Errorf("failed to get interface: %v", err)
@@ -288,7 +286,7 @@ func assignIPAddress(ipAddress string) error {
return netlink.AddrAdd(link, addr) return netlink.AddrAdd(link, addr)
} }
func ensureWireguardPeers(peers []Peer) error { func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
// get the current peers // get the current peers
device, err := wgClient.Device(interfaceName) device, err := wgClient.Device(interfaceName)
if err != nil { if err != nil {
@@ -311,7 +309,7 @@ func ensureWireguardPeers(peers []Peer) error {
} }
} }
if !found { if !found {
err := removePeer(peer) err := s.removePeer(peer)
if err != nil { if err != nil {
return fmt.Errorf("failed to remove peer: %v", err) return fmt.Errorf("failed to remove peer: %v", err)
} }
@@ -328,7 +326,7 @@ func ensureWireguardPeers(peers []Peer) error {
} }
} }
if !found { if !found {
err := addPeer(configPeer) err := s.addPeer(configPeer)
if err != nil { if err != nil {
return fmt.Errorf("failed to add peer: %v", err) return fmt.Errorf("failed to add peer: %v", err)
} }
@@ -338,7 +336,7 @@ func ensureWireguardPeers(peers []Peer) error {
return nil return nil
} }
func ensureMSSClamping() error { func (s *WireGuardService) ensureMSSClamping() error {
// Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20)) // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
mssValue := mtuInt - 40 mssValue := mtuInt - 40
@@ -426,7 +424,7 @@ func ensureMSSClamping() error {
return nil return nil
} }
func handleAddPeer(msg websocket.WSMessage) { func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
var peer Peer var peer Peer
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
@@ -438,13 +436,13 @@ func handleAddPeer(msg websocket.WSMessage) {
logger.Info("Error unmarshaling target data: %v", err) logger.Info("Error unmarshaling target data: %v", err)
} }
err = addPeer(peer) err = s.addPeer(peer)
if err != nil { if err != nil {
return return
} }
} }
func addPeer(peer Peer) error { func (s *WireGuardService) addPeer(peer Peer) error {
pubKey, err := wgtypes.ParseKey(peer.PublicKey) pubKey, err := wgtypes.ParseKey(peer.PublicKey)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse public key: %v", err) return fmt.Errorf("failed to parse public key: %v", err)
@@ -478,7 +476,7 @@ func addPeer(peer Peer) error {
return nil return nil
} }
func handleRemovePeer(msg websocket.WSMessage) { func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) {
// parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" } // parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" }
type RemoveRequest struct { type RemoveRequest struct {
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
@@ -495,13 +493,13 @@ func handleRemovePeer(msg websocket.WSMessage) {
return return
} }
if err := removePeer(request.PublicKey); err != nil { if err := s.removePeer(request.PublicKey); err != nil {
logger.Info("Error removing peer: %v", err) logger.Info("Error removing peer: %v", err)
return return
} }
} }
func removePeer(publicKey string) error { func (s *WireGuardService) removePeer(publicKey string) error {
pubKey, err := wgtypes.ParseKey(publicKey) pubKey, err := wgtypes.ParseKey(publicKey)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse public key: %v", err) return fmt.Errorf("failed to parse public key: %v", err)
@@ -525,18 +523,18 @@ func removePeer(publicKey string) error {
return nil return nil
} }
func periodicBandwidthCheck(endpoint string) { func (s *WireGuardService) periodicBandwidthCheck() {
ticker := time.NewTicker(10 * time.Second) ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop() defer ticker.Stop()
for range ticker.C { for range ticker.C {
if err := reportPeerBandwidth(endpoint); err != nil { if err := s.reportPeerBandwidth(); err != nil {
logger.Info("Failed to report peer bandwidth: %v", err) logger.Info("Failed to report peer bandwidth: %v", err)
} }
} }
} }
func calculatePeerBandwidth() ([]PeerBandwidth, error) { func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
device, err := wgClient.Device(interfaceName) device, err := wgClient.Device(interfaceName)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get device: %v", err) return nil, fmt.Errorf("failed to get device: %v", err)
@@ -621,8 +619,8 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
return peerBandwidths, nil return peerBandwidths, nil
} }
func reportPeerBandwidth(apiURL string) error { func (s *WireGuardService) reportPeerBandwidth() error {
bandwidths, err := calculatePeerBandwidth() bandwidths, err := s.calculatePeerBandwidth()
if err != nil { if err != nil {
return fmt.Errorf("failed to calculate peer bandwidth: %v", err) return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
} }
@@ -632,15 +630,10 @@ func reportPeerBandwidth(apiURL string) error {
return fmt.Errorf("failed to marshal bandwidth data: %v", err) return fmt.Errorf("failed to marshal bandwidth data: %v", err)
} }
resp, err := http.Post(apiURL, "application/json", bytes.NewBuffer(jsonData)) err = s.client.SendMessage("wg/bandwidth", jsonData)
if err != nil { if err != nil {
return fmt.Errorf("failed to send bandwidth data: %v", err) return fmt.Errorf("failed to send bandwidth data: %v", err)
} }
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("API returned non-OK status: %s", resp.Status)
}
return nil return nil
} }