Move wg into more of a class

This commit is contained in:
Owen
2025-02-20 20:37:31 -05:00
parent e8bd55bed9
commit f69a7f647d
2 changed files with 184 additions and 160 deletions

49
main.go
View File

@@ -20,6 +20,7 @@ import (
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/websocket"
"github.com/fosrl/newt/wg"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
@@ -246,15 +247,18 @@ func resolveDomain(domain string) (string, error) {
func main() {
var (
endpoint string
id string
secret string
mtu string
mtuInt int
dns string
privateKey wgtypes.Key
err error
logLevel string
endpoint string
id string
secret string
mtu string
mtuInt int
dns string
privateKey wgtypes.Key
err error
logLevel string
interfaceName string
generateAndSaveKeyTo string
reachableAt string
)
// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
@@ -264,6 +268,9 @@ func main() {
mtu = os.Getenv("MTU")
dns = os.Getenv("DNS")
logLevel = os.Getenv("LOG_LEVEL")
interfaceName = os.Getenv("INTERFACE")
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
reachableAt = os.Getenv("REACHABLE_AT")
if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@@ -283,6 +290,15 @@ func main() {
if logLevel == "" {
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
}
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg-1", "Name of the WireGuard interface")
}
if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
}
if reachableAt == "" {
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
}
// do a --version check
version := flag.Bool("version", false, "Print the version")
@@ -319,6 +335,21 @@ func main() {
logger.Fatal("Failed to create client: %v", err)
}
// Create WireGuard service
wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
if err != nil {
logger.Fatal("Failed to create WireGuard service: %v", err)
}
// defer wgService.Close()
// Start the WireGuard service
if err := wgService.Start(); err != nil {
logger.Fatal("Failed to start WireGuard service: %v", err)
}
// Start bandwidth reporting
wgService.StartBandwidthReporting()
// Create TUN device and network stack
var tun tun.Device
var tnet *netstack.Net

295
wg/wg.go
View File

@@ -3,14 +3,10 @@ package wg
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io"
"net"
"net/http"
"os"
"os/exec"
"strconv"
"strings"
"sync"
"time"
@@ -58,147 +54,149 @@ var (
wgClient *wgctrl.Client
)
func main() {
var (
err error
wgconfig WgConfig
remoteConfigURL string
generateAndSaveKeyTo string
reachableAt string
logLevel string
mtu string
)
type WireGuardService struct {
interfaceName string
mtu int
client *websocket.Client
wgClient *wgctrl.Client
config WgConfig
key wgtypes.Key
reachableAt string
generateAndSaveKeyTo string
lastReadings map[string]PeerReading
mu sync.Mutex
}
interfaceName = os.Getenv("INTERFACE")
remoteConfigURL = os.Getenv("REMOTE_CONFIG")
listenAddr = os.Getenv("LISTEN")
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
reachableAt = os.Getenv("REACHABLE_AT")
logLevel = os.Getenv("LOG_LEVEL")
mtu = os.Getenv("MTU")
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
}
if remoteConfigURL == "" {
flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL to fetch remote configuration")
}
if generateAndSaveKeyTo == "" {
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
}
if reachableAt == "" {
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
}
if logLevel == "" {
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
}
if mtu == "" {
flag.StringVar(&mtu, "mtu", "1280", "MTU of the WireGuard interface")
}
flag.Parse()
mtuInt, err = strconv.Atoi(mtu)
func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) {
wgClient, err := wgctrl.New()
if err != nil {
logger.Fatal("Failed to parse MTU: %v", err)
return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
}
var key wgtypes.Key
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate private key: %v", err)
}
service := &WireGuardService{
interfaceName: interfaceName,
mtu: mtu,
client: wsClient,
wgClient: wgClient,
key: key,
reachableAt: reachableAt,
generateAndSaveKeyTo: generateAndSaveKeyTo,
lastReadings: make(map[string]PeerReading),
}
// Register websocket handlers
wsClient.RegisterHandler("wg/peer/config", service.handleConfig)
wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer)
wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer)
// Register connect handler to initiate configuration
wsClient.OnConnect(service.handleConnect)
return service, nil
}
func (s *WireGuardService) handleConnect() error {
logger.Debug("Public key: %s", s.key.PublicKey())
err := s.client.SendMessage("wg/register", map[string]interface{}{
"publicKey": fmt.Sprintf("%s", s.key.PublicKey()),
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent registration message")
return nil
}
func (s *WireGuardService) Start() error {
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
if generateAndSaveKeyTo != "" {
if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
// generate a new private key
key, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Fatal("Failed to generate private key: %v", err)
}
// save the key to the file
err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644)
if err != nil {
logger.Fatal("Failed to save private key: %v", err)
}
} else {
keyData, err := os.ReadFile(generateAndSaveKeyTo)
if err != nil {
logger.Fatal("Failed to read private key: %v", err)
}
key, err = wgtypes.ParseKey(string(keyData))
if err != nil {
logger.Fatal("Failed to parse private key: %v", err)
}
}
} else {
// if no generateAndSaveKeyTo is provided, ensure that the private key is provided
if wgconfig.PrivateKey == "" {
// generate a new one
key, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Fatal("Failed to generate private key: %v", err)
}
}
}
// loop until we get the config
for wgconfig.PrivateKey == "" {
logger.Info("Fetching remote config from %s", remoteConfigURL)
wgconfig, err = loadRemoteConfig(remoteConfigURL, key, reachableAt)
if _, err := os.Stat(s.generateAndSaveKeyTo); os.IsNotExist(err) {
// generate a new private key
s.key, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Error("Failed to load configuration: %v", err)
time.Sleep(5 * time.Second)
continue
logger.Fatal("Failed to generate private key: %v", err)
}
// save the key to the file
err = os.WriteFile(s.generateAndSaveKeyTo, []byte(s.key.String()), 0644)
if err != nil {
logger.Fatal("Failed to save private key: %v", err)
}
wgconfig.PrivateKey = key.String()
}
wgClient, err = wgctrl.New()
if err != nil {
logger.Fatal("Failed to create WireGuard client: %v", err)
}
defer wgClient.Close()
// Ensure the WireGuard interface exists and is configured
if err := ensureWireguardInterface(wgconfig); err != nil {
logger.Fatal("Failed to ensure WireGuard interface: %v", err)
}
// Ensure the WireGuard peers exist
ensureWireguardPeers(wgconfig.Peers)
// go periodicBandwidthCheck(reportBandwidthTo)
}
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
var body *bytes.Buffer
if reachableAt == "" {
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, key.PublicKey().String())))
} else {
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, key.PublicKey().String(), reachableAt)))
keyData, err := os.ReadFile(s.generateAndSaveKeyTo)
if err != nil {
logger.Fatal("Failed to read private key: %v", err)
}
s.key, err = wgtypes.ParseKey(string(keyData))
if err != nil {
logger.Fatal("Failed to parse private key: %v", err)
}
}
resp, err := http.Post(url, "application/json", body)
// Get initial configuration
err := s.loadRemoteConfig()
if err != nil {
// print the error
logger.Error("Error fetching remote config %s: %v", url, err)
return WgConfig{}, err
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return WgConfig{}, err
return fmt.Errorf("failed to load initial configuration: %v", err)
}
var config WgConfig
err = json.Unmarshal(data, &config)
return config, err
return nil
}
func ensureWireguardInterface(wgconfig WgConfig) error {
func (s *WireGuardService) StartBandwidthReporting() {
go s.periodicBandwidthCheck()
}
func (s *WireGuardService) loadRemoteConfig() error {
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, s.key.PublicKey().String(), s.reachableAt)))
// send a ws message to the server to get the config
err := s.client.SendMessage("wg/config/get", body)
if err != nil {
return fmt.Errorf("failed to send config request: %v", err)
}
return nil
}
func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
var config WgConfig
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
}
if err := json.Unmarshal(jsonData, &config); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
}
s.config = config
// Ensure the WireGuard interface and peers are configured
if err := s.ensureWireguardInterface(config); err != nil {
logger.Error("Failed to ensure WireGuard interface: %v", err)
}
if err := s.ensureWireguardPeers(config.Peers); err != nil {
logger.Error("Failed to ensure WireGuard peers: %v", err)
}
}
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// Check if the WireGuard interface exists
_, err := netlink.LinkByName(interfaceName)
if err != nil {
if _, ok := err.(netlink.LinkNotFoundError); ok {
// Interface doesn't exist, so create it
err = createWireGuardInterface()
err = s.createWireGuardInterface()
if err != nil {
logger.Fatal("Failed to create WireGuard interface: %v", err)
}
@@ -212,7 +210,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
}
// Assign IP address to the interface
err = assignIPAddress(wgconfig.IpAddress)
err = s.assignIPAddress(wgconfig.IpAddress)
if err != nil {
logger.Fatal("Failed to assign IP address: %v", err)
}
@@ -257,7 +255,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
return fmt.Errorf("failed to bring up interface: %v", err)
}
if err := ensureMSSClamping(); err != nil {
if err := s.ensureMSSClamping(); err != nil {
logger.Warn("Failed to ensure MSS clamping: %v", err)
}
@@ -266,7 +264,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
return nil
}
func createWireGuardInterface() error {
func (s *WireGuardService) createWireGuardInterface() error {
wgLink := &netlink.GenericLink{
LinkAttrs: netlink.LinkAttrs{Name: interfaceName},
LinkType: "wireguard",
@@ -274,7 +272,7 @@ func createWireGuardInterface() error {
return netlink.LinkAdd(wgLink)
}
func assignIPAddress(ipAddress string) error {
func (s *WireGuardService) assignIPAddress(ipAddress string) error {
link, err := netlink.LinkByName(interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface: %v", err)
@@ -288,7 +286,7 @@ func assignIPAddress(ipAddress string) error {
return netlink.AddrAdd(link, addr)
}
func ensureWireguardPeers(peers []Peer) error {
func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
// get the current peers
device, err := wgClient.Device(interfaceName)
if err != nil {
@@ -311,7 +309,7 @@ func ensureWireguardPeers(peers []Peer) error {
}
}
if !found {
err := removePeer(peer)
err := s.removePeer(peer)
if err != nil {
return fmt.Errorf("failed to remove peer: %v", err)
}
@@ -328,7 +326,7 @@ func ensureWireguardPeers(peers []Peer) error {
}
}
if !found {
err := addPeer(configPeer)
err := s.addPeer(configPeer)
if err != nil {
return fmt.Errorf("failed to add peer: %v", err)
}
@@ -338,7 +336,7 @@ func ensureWireguardPeers(peers []Peer) error {
return nil
}
func ensureMSSClamping() error {
func (s *WireGuardService) ensureMSSClamping() error {
// Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
mssValue := mtuInt - 40
@@ -426,7 +424,7 @@ func ensureMSSClamping() error {
return nil
}
func handleAddPeer(msg websocket.WSMessage) {
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
var peer Peer
jsonData, err := json.Marshal(msg.Data)
@@ -438,13 +436,13 @@ func handleAddPeer(msg websocket.WSMessage) {
logger.Info("Error unmarshaling target data: %v", err)
}
err = addPeer(peer)
err = s.addPeer(peer)
if err != nil {
return
}
}
func addPeer(peer Peer) error {
func (s *WireGuardService) addPeer(peer Peer) error {
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
if err != nil {
return fmt.Errorf("failed to parse public key: %v", err)
@@ -478,7 +476,7 @@ func addPeer(peer Peer) error {
return nil
}
func handleRemovePeer(msg websocket.WSMessage) {
func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) {
// parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" }
type RemoveRequest struct {
PublicKey string `json:"publicKey"`
@@ -495,13 +493,13 @@ func handleRemovePeer(msg websocket.WSMessage) {
return
}
if err := removePeer(request.PublicKey); err != nil {
if err := s.removePeer(request.PublicKey); err != nil {
logger.Info("Error removing peer: %v", err)
return
}
}
func removePeer(publicKey string) error {
func (s *WireGuardService) removePeer(publicKey string) error {
pubKey, err := wgtypes.ParseKey(publicKey)
if err != nil {
return fmt.Errorf("failed to parse public key: %v", err)
@@ -525,18 +523,18 @@ func removePeer(publicKey string) error {
return nil
}
func periodicBandwidthCheck(endpoint string) {
func (s *WireGuardService) periodicBandwidthCheck() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for range ticker.C {
if err := reportPeerBandwidth(endpoint); err != nil {
if err := s.reportPeerBandwidth(); err != nil {
logger.Info("Failed to report peer bandwidth: %v", err)
}
}
}
func calculatePeerBandwidth() ([]PeerBandwidth, error) {
func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
device, err := wgClient.Device(interfaceName)
if err != nil {
return nil, fmt.Errorf("failed to get device: %v", err)
@@ -621,8 +619,8 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
return peerBandwidths, nil
}
func reportPeerBandwidth(apiURL string) error {
bandwidths, err := calculatePeerBandwidth()
func (s *WireGuardService) reportPeerBandwidth() error {
bandwidths, err := s.calculatePeerBandwidth()
if err != nil {
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
}
@@ -632,15 +630,10 @@ func reportPeerBandwidth(apiURL string) error {
return fmt.Errorf("failed to marshal bandwidth data: %v", err)
}
resp, err := http.Post(apiURL, "application/json", bytes.NewBuffer(jsonData))
err = s.client.SendMessage("wg/bandwidth", jsonData)
if err != nil {
return fmt.Errorf("failed to send bandwidth data: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("API returned non-OK status: %s", resp.Status)
}
return nil
}