Move wg into more of a class

This commit is contained in:
Owen
2025-02-20 20:37:31 -05:00
parent e8bd55bed9
commit f69a7f647d
2 changed files with 184 additions and 160 deletions

49
main.go
View File

@@ -20,6 +20,7 @@ import (
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/newt/proxy" "github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/websocket" "github.com/fosrl/newt/websocket"
"github.com/fosrl/newt/wg"
"golang.org/x/net/icmp" "golang.org/x/net/icmp"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
@@ -246,15 +247,18 @@ func resolveDomain(domain string) (string, error) {
func main() { func main() {
var ( var (
endpoint string endpoint string
id string id string
secret string secret string
mtu string mtu string
mtuInt int mtuInt int
dns string dns string
privateKey wgtypes.Key privateKey wgtypes.Key
err error err error
logLevel string logLevel string
interfaceName string
generateAndSaveKeyTo string
reachableAt string
) )
// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values // if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
@@ -264,6 +268,9 @@ func main() {
mtu = os.Getenv("MTU") mtu = os.Getenv("MTU")
dns = os.Getenv("DNS") dns = os.Getenv("DNS")
logLevel = os.Getenv("LOG_LEVEL") logLevel = os.Getenv("LOG_LEVEL")
interfaceName = os.Getenv("INTERFACE")
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
reachableAt = os.Getenv("REACHABLE_AT")
if endpoint == "" { if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@@ -283,6 +290,15 @@ func main() {
if logLevel == "" { if logLevel == "" {
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
} }
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg-1", "Name of the WireGuard interface")
}
if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
}
if reachableAt == "" {
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
}
// do a --version check // do a --version check
version := flag.Bool("version", false, "Print the version") version := flag.Bool("version", false, "Print the version")
@@ -319,6 +335,21 @@ func main() {
logger.Fatal("Failed to create client: %v", err) logger.Fatal("Failed to create client: %v", err)
} }
// Create WireGuard service
wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
if err != nil {
logger.Fatal("Failed to create WireGuard service: %v", err)
}
// defer wgService.Close()
// Start the WireGuard service
if err := wgService.Start(); err != nil {
logger.Fatal("Failed to start WireGuard service: %v", err)
}
// Start bandwidth reporting
wgService.StartBandwidthReporting()
// Create TUN device and network stack // Create TUN device and network stack
var tun tun.Device var tun tun.Device
var tnet *netstack.Net var tnet *netstack.Net

295
wg/wg.go
View File

@@ -3,14 +3,10 @@ package wg
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"flag"
"fmt" "fmt"
"io"
"net" "net"
"net/http"
"os" "os"
"os/exec" "os/exec"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -58,147 +54,149 @@ var (
wgClient *wgctrl.Client wgClient *wgctrl.Client
) )
func main() { type WireGuardService struct {
var ( interfaceName string
err error mtu int
wgconfig WgConfig client *websocket.Client
remoteConfigURL string wgClient *wgctrl.Client
generateAndSaveKeyTo string config WgConfig
reachableAt string key wgtypes.Key
logLevel string reachableAt string
mtu string generateAndSaveKeyTo string
) lastReadings map[string]PeerReading
mu sync.Mutex
}
interfaceName = os.Getenv("INTERFACE") func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) {
remoteConfigURL = os.Getenv("REMOTE_CONFIG") wgClient, err := wgctrl.New()
listenAddr = os.Getenv("LISTEN")
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
reachableAt = os.Getenv("REACHABLE_AT")
logLevel = os.Getenv("LOG_LEVEL")
mtu = os.Getenv("MTU")
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
}
if remoteConfigURL == "" {
flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL to fetch remote configuration")
}
if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
}
if reachableAt == "" {
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
}
if logLevel == "" {
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
}
if mtu == "" {
flag.StringVar(&mtu, "mtu", "1280", "MTU of the WireGuard interface")
}
flag.Parse()
mtuInt, err = strconv.Atoi(mtu)
if err != nil { if err != nil {
logger.Fatal("Failed to parse MTU: %v", err) return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
} }
var key wgtypes.Key key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate private key: %v", err)
}
service := &WireGuardService{
interfaceName: interfaceName,
mtu: mtu,
client: wsClient,
wgClient: wgClient,
key: key,
reachableAt: reachableAt,
generateAndSaveKeyTo: generateAndSaveKeyTo,
lastReadings: make(map[string]PeerReading),
}
// Register websocket handlers
wsClient.RegisterHandler("wg/peer/config", service.handleConfig)
wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer)
wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer)
// Register connect handler to initiate configuration
wsClient.OnConnect(service.handleConnect)
return service, nil
}
func (s *WireGuardService) handleConnect() error {
logger.Debug("Public key: %s", s.key.PublicKey())
err := s.client.SendMessage("wg/register", map[string]interface{}{
"publicKey": fmt.Sprintf("%s", s.key.PublicKey()),
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent registration message")
return nil
}
func (s *WireGuardService) Start() error {
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
if generateAndSaveKeyTo != "" { if _, err := os.Stat(s.generateAndSaveKeyTo); os.IsNotExist(err) {
if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { // generate a new private key
// generate a new private key s.key, err = wgtypes.GeneratePrivateKey()
key, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Fatal("Failed to generate private key: %v", err)
}
// save the key to the file
err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644)
if err != nil {
logger.Fatal("Failed to save private key: %v", err)
}
} else {
keyData, err := os.ReadFile(generateAndSaveKeyTo)
if err != nil {
logger.Fatal("Failed to read private key: %v", err)
}
key, err = wgtypes.ParseKey(string(keyData))
if err != nil {
logger.Fatal("Failed to parse private key: %v", err)
}
}
} else {
// if no generateAndSaveKeyTo is provided, ensure that the private key is provided
if wgconfig.PrivateKey == "" {
// generate a new one
key, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Fatal("Failed to generate private key: %v", err)
}
}
}
// loop until we get the config
for wgconfig.PrivateKey == "" {
logger.Info("Fetching remote config from %s", remoteConfigURL)
wgconfig, err = loadRemoteConfig(remoteConfigURL, key, reachableAt)
if err != nil { if err != nil {
logger.Error("Failed to load configuration: %v", err) logger.Fatal("Failed to generate private key: %v", err)
time.Sleep(5 * time.Second) }
continue // save the key to the file
err = os.WriteFile(s.generateAndSaveKeyTo, []byte(s.key.String()), 0644)
if err != nil {
logger.Fatal("Failed to save private key: %v", err)
} }
wgconfig.PrivateKey = key.String()
}
wgClient, err = wgctrl.New()
if err != nil {
logger.Fatal("Failed to create WireGuard client: %v", err)
}
defer wgClient.Close()
// Ensure the WireGuard interface exists and is configured
if err := ensureWireguardInterface(wgconfig); err != nil {
logger.Fatal("Failed to ensure WireGuard interface: %v", err)
}
// Ensure the WireGuard peers exist
ensureWireguardPeers(wgconfig.Peers)
// go periodicBandwidthCheck(reportBandwidthTo)
}
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
var body *bytes.Buffer
if reachableAt == "" {
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, key.PublicKey().String())))
} else { } else {
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, key.PublicKey().String(), reachableAt))) keyData, err := os.ReadFile(s.generateAndSaveKeyTo)
if err != nil {
logger.Fatal("Failed to read private key: %v", err)
}
s.key, err = wgtypes.ParseKey(string(keyData))
if err != nil {
logger.Fatal("Failed to parse private key: %v", err)
}
} }
resp, err := http.Post(url, "application/json", body)
// Get initial configuration
err := s.loadRemoteConfig()
if err != nil { if err != nil {
// print the error return fmt.Errorf("failed to load initial configuration: %v", err)
logger.Error("Error fetching remote config %s: %v", url, err)
return WgConfig{}, err
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return WgConfig{}, err
} }
var config WgConfig return nil
err = json.Unmarshal(data, &config)
return config, err
} }
func ensureWireguardInterface(wgconfig WgConfig) error { func (s *WireGuardService) StartBandwidthReporting() {
go s.periodicBandwidthCheck()
}
func (s *WireGuardService) loadRemoteConfig() error {
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, s.key.PublicKey().String(), s.reachableAt)))
// send a ws message to the server to get the config
err := s.client.SendMessage("wg/config/get", body)
if err != nil {
return fmt.Errorf("failed to send config request: %v", err)
}
return nil
}
func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
var config WgConfig
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
}
if err := json.Unmarshal(jsonData, &config); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
}
s.config = config
// Ensure the WireGuard interface and peers are configured
if err := s.ensureWireguardInterface(config); err != nil {
logger.Error("Failed to ensure WireGuard interface: %v", err)
}
if err := s.ensureWireguardPeers(config.Peers); err != nil {
logger.Error("Failed to ensure WireGuard peers: %v", err)
}
}
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// Check if the WireGuard interface exists // Check if the WireGuard interface exists
_, err := netlink.LinkByName(interfaceName) _, err := netlink.LinkByName(interfaceName)
if err != nil { if err != nil {
if _, ok := err.(netlink.LinkNotFoundError); ok { if _, ok := err.(netlink.LinkNotFoundError); ok {
// Interface doesn't exist, so create it // Interface doesn't exist, so create it
err = createWireGuardInterface() err = s.createWireGuardInterface()
if err != nil { if err != nil {
logger.Fatal("Failed to create WireGuard interface: %v", err) logger.Fatal("Failed to create WireGuard interface: %v", err)
} }
@@ -212,7 +210,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
} }
// Assign IP address to the interface // Assign IP address to the interface
err = assignIPAddress(wgconfig.IpAddress) err = s.assignIPAddress(wgconfig.IpAddress)
if err != nil { if err != nil {
logger.Fatal("Failed to assign IP address: %v", err) logger.Fatal("Failed to assign IP address: %v", err)
} }
@@ -257,7 +255,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
return fmt.Errorf("failed to bring up interface: %v", err) return fmt.Errorf("failed to bring up interface: %v", err)
} }
if err := ensureMSSClamping(); err != nil { if err := s.ensureMSSClamping(); err != nil {
logger.Warn("Failed to ensure MSS clamping: %v", err) logger.Warn("Failed to ensure MSS clamping: %v", err)
} }
@@ -266,7 +264,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
return nil return nil
} }
func createWireGuardInterface() error { func (s *WireGuardService) createWireGuardInterface() error {
wgLink := &netlink.GenericLink{ wgLink := &netlink.GenericLink{
LinkAttrs: netlink.LinkAttrs{Name: interfaceName}, LinkAttrs: netlink.LinkAttrs{Name: interfaceName},
LinkType: "wireguard", LinkType: "wireguard",
@@ -274,7 +272,7 @@ func createWireGuardInterface() error {
return netlink.LinkAdd(wgLink) return netlink.LinkAdd(wgLink)
} }
func assignIPAddress(ipAddress string) error { func (s *WireGuardService) assignIPAddress(ipAddress string) error {
link, err := netlink.LinkByName(interfaceName) link, err := netlink.LinkByName(interfaceName)
if err != nil { if err != nil {
return fmt.Errorf("failed to get interface: %v", err) return fmt.Errorf("failed to get interface: %v", err)
@@ -288,7 +286,7 @@ func assignIPAddress(ipAddress string) error {
return netlink.AddrAdd(link, addr) return netlink.AddrAdd(link, addr)
} }
func ensureWireguardPeers(peers []Peer) error { func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
// get the current peers // get the current peers
device, err := wgClient.Device(interfaceName) device, err := wgClient.Device(interfaceName)
if err != nil { if err != nil {
@@ -311,7 +309,7 @@ func ensureWireguardPeers(peers []Peer) error {
} }
} }
if !found { if !found {
err := removePeer(peer) err := s.removePeer(peer)
if err != nil { if err != nil {
return fmt.Errorf("failed to remove peer: %v", err) return fmt.Errorf("failed to remove peer: %v", err)
} }
@@ -328,7 +326,7 @@ func ensureWireguardPeers(peers []Peer) error {
} }
} }
if !found { if !found {
err := addPeer(configPeer) err := s.addPeer(configPeer)
if err != nil { if err != nil {
return fmt.Errorf("failed to add peer: %v", err) return fmt.Errorf("failed to add peer: %v", err)
} }
@@ -338,7 +336,7 @@ func ensureWireguardPeers(peers []Peer) error {
return nil return nil
} }
func ensureMSSClamping() error { func (s *WireGuardService) ensureMSSClamping() error {
// Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20)) // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
mssValue := mtuInt - 40 mssValue := mtuInt - 40
@@ -426,7 +424,7 @@ func ensureMSSClamping() error {
return nil return nil
} }
func handleAddPeer(msg websocket.WSMessage) { func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
var peer Peer var peer Peer
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
@@ -438,13 +436,13 @@ func handleAddPeer(msg websocket.WSMessage) {
logger.Info("Error unmarshaling target data: %v", err) logger.Info("Error unmarshaling target data: %v", err)
} }
err = addPeer(peer) err = s.addPeer(peer)
if err != nil { if err != nil {
return return
} }
} }
func addPeer(peer Peer) error { func (s *WireGuardService) addPeer(peer Peer) error {
pubKey, err := wgtypes.ParseKey(peer.PublicKey) pubKey, err := wgtypes.ParseKey(peer.PublicKey)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse public key: %v", err) return fmt.Errorf("failed to parse public key: %v", err)
@@ -478,7 +476,7 @@ func addPeer(peer Peer) error {
return nil return nil
} }
func handleRemovePeer(msg websocket.WSMessage) { func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) {
// parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" } // parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" }
type RemoveRequest struct { type RemoveRequest struct {
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
@@ -495,13 +493,13 @@ func handleRemovePeer(msg websocket.WSMessage) {
return return
} }
if err := removePeer(request.PublicKey); err != nil { if err := s.removePeer(request.PublicKey); err != nil {
logger.Info("Error removing peer: %v", err) logger.Info("Error removing peer: %v", err)
return return
} }
} }
func removePeer(publicKey string) error { func (s *WireGuardService) removePeer(publicKey string) error {
pubKey, err := wgtypes.ParseKey(publicKey) pubKey, err := wgtypes.ParseKey(publicKey)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse public key: %v", err) return fmt.Errorf("failed to parse public key: %v", err)
@@ -525,18 +523,18 @@ func removePeer(publicKey string) error {
return nil return nil
} }
func periodicBandwidthCheck(endpoint string) { func (s *WireGuardService) periodicBandwidthCheck() {
ticker := time.NewTicker(10 * time.Second) ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop() defer ticker.Stop()
for range ticker.C { for range ticker.C {
if err := reportPeerBandwidth(endpoint); err != nil { if err := s.reportPeerBandwidth(); err != nil {
logger.Info("Failed to report peer bandwidth: %v", err) logger.Info("Failed to report peer bandwidth: %v", err)
} }
} }
} }
func calculatePeerBandwidth() ([]PeerBandwidth, error) { func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
device, err := wgClient.Device(interfaceName) device, err := wgClient.Device(interfaceName)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get device: %v", err) return nil, fmt.Errorf("failed to get device: %v", err)
@@ -621,8 +619,8 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
return peerBandwidths, nil return peerBandwidths, nil
} }
func reportPeerBandwidth(apiURL string) error { func (s *WireGuardService) reportPeerBandwidth() error {
bandwidths, err := calculatePeerBandwidth() bandwidths, err := s.calculatePeerBandwidth()
if err != nil { if err != nil {
return fmt.Errorf("failed to calculate peer bandwidth: %v", err) return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
} }
@@ -632,15 +630,10 @@ func reportPeerBandwidth(apiURL string) error {
return fmt.Errorf("failed to marshal bandwidth data: %v", err) return fmt.Errorf("failed to marshal bandwidth data: %v", err)
} }
resp, err := http.Post(apiURL, "application/json", bytes.NewBuffer(jsonData)) err = s.client.SendMessage("wg/bandwidth", jsonData)
if err != nil { if err != nil {
return fmt.Errorf("failed to send bandwidth data: %v", err) return fmt.Errorf("failed to send bandwidth data: %v", err)
} }
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("API returned non-OK status: %s", resp.Status)
}
return nil return nil
} }