mirror of
https://github.com/fosrl/newt.git
synced 2026-02-08 05:56:40 +00:00
Move wg into more of a class
This commit is contained in:
49
main.go
49
main.go
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/proxy"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"github.com/fosrl/newt/wg"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
@@ -246,15 +247,18 @@ func resolveDomain(domain string) (string, error) {
|
||||
|
||||
func main() {
|
||||
var (
|
||||
endpoint string
|
||||
id string
|
||||
secret string
|
||||
mtu string
|
||||
mtuInt int
|
||||
dns string
|
||||
privateKey wgtypes.Key
|
||||
err error
|
||||
logLevel string
|
||||
endpoint string
|
||||
id string
|
||||
secret string
|
||||
mtu string
|
||||
mtuInt int
|
||||
dns string
|
||||
privateKey wgtypes.Key
|
||||
err error
|
||||
logLevel string
|
||||
interfaceName string
|
||||
generateAndSaveKeyTo string
|
||||
reachableAt string
|
||||
)
|
||||
|
||||
// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
|
||||
@@ -264,6 +268,9 @@ func main() {
|
||||
mtu = os.Getenv("MTU")
|
||||
dns = os.Getenv("DNS")
|
||||
logLevel = os.Getenv("LOG_LEVEL")
|
||||
interfaceName = os.Getenv("INTERFACE")
|
||||
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
|
||||
reachableAt = os.Getenv("REACHABLE_AT")
|
||||
|
||||
if endpoint == "" {
|
||||
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
||||
@@ -283,6 +290,15 @@ func main() {
|
||||
if logLevel == "" {
|
||||
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||
}
|
||||
if interfaceName == "" {
|
||||
flag.StringVar(&interfaceName, "interface", "wg-1", "Name of the WireGuard interface")
|
||||
}
|
||||
if generateAndSaveKeyTo == "" {
|
||||
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
||||
}
|
||||
if reachableAt == "" {
|
||||
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
|
||||
}
|
||||
|
||||
// do a --version check
|
||||
version := flag.Bool("version", false, "Print the version")
|
||||
@@ -319,6 +335,21 @@ func main() {
|
||||
logger.Fatal("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
// Create WireGuard service
|
||||
wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard service: %v", err)
|
||||
}
|
||||
// defer wgService.Close()
|
||||
|
||||
// Start the WireGuard service
|
||||
if err := wgService.Start(); err != nil {
|
||||
logger.Fatal("Failed to start WireGuard service: %v", err)
|
||||
}
|
||||
|
||||
// Start bandwidth reporting
|
||||
wgService.StartBandwidthReporting()
|
||||
|
||||
// Create TUN device and network stack
|
||||
var tun tun.Device
|
||||
var tnet *netstack.Net
|
||||
|
||||
295
wg/wg.go
295
wg/wg.go
@@ -3,14 +3,10 @@ package wg
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -58,147 +54,149 @@ var (
|
||||
wgClient *wgctrl.Client
|
||||
)
|
||||
|
||||
func main() {
|
||||
var (
|
||||
err error
|
||||
wgconfig WgConfig
|
||||
remoteConfigURL string
|
||||
generateAndSaveKeyTo string
|
||||
reachableAt string
|
||||
logLevel string
|
||||
mtu string
|
||||
)
|
||||
type WireGuardService struct {
|
||||
interfaceName string
|
||||
mtu int
|
||||
client *websocket.Client
|
||||
wgClient *wgctrl.Client
|
||||
config WgConfig
|
||||
key wgtypes.Key
|
||||
reachableAt string
|
||||
generateAndSaveKeyTo string
|
||||
lastReadings map[string]PeerReading
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
interfaceName = os.Getenv("INTERFACE")
|
||||
remoteConfigURL = os.Getenv("REMOTE_CONFIG")
|
||||
listenAddr = os.Getenv("LISTEN")
|
||||
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
|
||||
reachableAt = os.Getenv("REACHABLE_AT")
|
||||
logLevel = os.Getenv("LOG_LEVEL")
|
||||
mtu = os.Getenv("MTU")
|
||||
|
||||
if interfaceName == "" {
|
||||
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
|
||||
}
|
||||
if remoteConfigURL == "" {
|
||||
flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL to fetch remote configuration")
|
||||
}
|
||||
if generateAndSaveKeyTo == "" {
|
||||
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
||||
}
|
||||
if reachableAt == "" {
|
||||
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
|
||||
}
|
||||
if logLevel == "" {
|
||||
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||
}
|
||||
if mtu == "" {
|
||||
flag.StringVar(&mtu, "mtu", "1280", "MTU of the WireGuard interface")
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
mtuInt, err = strconv.Atoi(mtu)
|
||||
func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) {
|
||||
wgClient, err := wgctrl.New()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to parse MTU: %v", err)
|
||||
return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
|
||||
}
|
||||
|
||||
var key wgtypes.Key
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate private key: %v", err)
|
||||
}
|
||||
|
||||
service := &WireGuardService{
|
||||
interfaceName: interfaceName,
|
||||
mtu: mtu,
|
||||
client: wsClient,
|
||||
wgClient: wgClient,
|
||||
key: key,
|
||||
reachableAt: reachableAt,
|
||||
generateAndSaveKeyTo: generateAndSaveKeyTo,
|
||||
lastReadings: make(map[string]PeerReading),
|
||||
}
|
||||
|
||||
// Register websocket handlers
|
||||
wsClient.RegisterHandler("wg/peer/config", service.handleConfig)
|
||||
wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer)
|
||||
wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer)
|
||||
|
||||
// Register connect handler to initiate configuration
|
||||
wsClient.OnConnect(service.handleConnect)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleConnect() error {
|
||||
logger.Debug("Public key: %s", s.key.PublicKey())
|
||||
|
||||
err := s.client.SendMessage("wg/register", map[string]interface{}{
|
||||
"publicKey": fmt.Sprintf("%s", s.key.PublicKey()),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("Sent registration message")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) Start() error {
|
||||
|
||||
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
|
||||
if generateAndSaveKeyTo != "" {
|
||||
if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
|
||||
// generate a new private key
|
||||
key, err = wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to generate private key: %v", err)
|
||||
}
|
||||
// save the key to the file
|
||||
err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to save private key: %v", err)
|
||||
}
|
||||
} else {
|
||||
keyData, err := os.ReadFile(generateAndSaveKeyTo)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to read private key: %v", err)
|
||||
}
|
||||
key, err = wgtypes.ParseKey(string(keyData))
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to parse private key: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// if no generateAndSaveKeyTo is provided, ensure that the private key is provided
|
||||
if wgconfig.PrivateKey == "" {
|
||||
// generate a new one
|
||||
key, err = wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to generate private key: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// loop until we get the config
|
||||
for wgconfig.PrivateKey == "" {
|
||||
logger.Info("Fetching remote config from %s", remoteConfigURL)
|
||||
wgconfig, err = loadRemoteConfig(remoteConfigURL, key, reachableAt)
|
||||
if _, err := os.Stat(s.generateAndSaveKeyTo); os.IsNotExist(err) {
|
||||
// generate a new private key
|
||||
s.key, err = wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
logger.Error("Failed to load configuration: %v", err)
|
||||
time.Sleep(5 * time.Second)
|
||||
continue
|
||||
logger.Fatal("Failed to generate private key: %v", err)
|
||||
}
|
||||
// save the key to the file
|
||||
err = os.WriteFile(s.generateAndSaveKeyTo, []byte(s.key.String()), 0644)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to save private key: %v", err)
|
||||
}
|
||||
wgconfig.PrivateKey = key.String()
|
||||
}
|
||||
|
||||
wgClient, err = wgctrl.New()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard client: %v", err)
|
||||
}
|
||||
defer wgClient.Close()
|
||||
|
||||
// Ensure the WireGuard interface exists and is configured
|
||||
if err := ensureWireguardInterface(wgconfig); err != nil {
|
||||
logger.Fatal("Failed to ensure WireGuard interface: %v", err)
|
||||
}
|
||||
|
||||
// Ensure the WireGuard peers exist
|
||||
ensureWireguardPeers(wgconfig.Peers)
|
||||
|
||||
// go periodicBandwidthCheck(reportBandwidthTo)
|
||||
}
|
||||
|
||||
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
|
||||
var body *bytes.Buffer
|
||||
if reachableAt == "" {
|
||||
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, key.PublicKey().String())))
|
||||
} else {
|
||||
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, key.PublicKey().String(), reachableAt)))
|
||||
keyData, err := os.ReadFile(s.generateAndSaveKeyTo)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to read private key: %v", err)
|
||||
}
|
||||
s.key, err = wgtypes.ParseKey(string(keyData))
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to parse private key: %v", err)
|
||||
}
|
||||
}
|
||||
resp, err := http.Post(url, "application/json", body)
|
||||
|
||||
// Get initial configuration
|
||||
err := s.loadRemoteConfig()
|
||||
if err != nil {
|
||||
// print the error
|
||||
logger.Error("Error fetching remote config %s: %v", url, err)
|
||||
return WgConfig{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return WgConfig{}, err
|
||||
return fmt.Errorf("failed to load initial configuration: %v", err)
|
||||
}
|
||||
|
||||
var config WgConfig
|
||||
err = json.Unmarshal(data, &config)
|
||||
|
||||
return config, err
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
func (s *WireGuardService) StartBandwidthReporting() {
|
||||
go s.periodicBandwidthCheck()
|
||||
}
|
||||
|
||||
func (s *WireGuardService) loadRemoteConfig() error {
|
||||
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, s.key.PublicKey().String(), s.reachableAt)))
|
||||
|
||||
// send a ws message to the server to get the config
|
||||
|
||||
err := s.client.SendMessage("wg/config/get", body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send config request: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
||||
var config WgConfig
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &config); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
}
|
||||
|
||||
s.config = config
|
||||
|
||||
// Ensure the WireGuard interface and peers are configured
|
||||
if err := s.ensureWireguardInterface(config); err != nil {
|
||||
logger.Error("Failed to ensure WireGuard interface: %v", err)
|
||||
}
|
||||
|
||||
if err := s.ensureWireguardPeers(config.Peers); err != nil {
|
||||
logger.Error("Failed to ensure WireGuard peers: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
// Check if the WireGuard interface exists
|
||||
_, err := netlink.LinkByName(interfaceName)
|
||||
if err != nil {
|
||||
if _, ok := err.(netlink.LinkNotFoundError); ok {
|
||||
// Interface doesn't exist, so create it
|
||||
err = createWireGuardInterface()
|
||||
err = s.createWireGuardInterface()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard interface: %v", err)
|
||||
}
|
||||
@@ -212,7 +210,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
}
|
||||
|
||||
// Assign IP address to the interface
|
||||
err = assignIPAddress(wgconfig.IpAddress)
|
||||
err = s.assignIPAddress(wgconfig.IpAddress)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to assign IP address: %v", err)
|
||||
}
|
||||
@@ -257,7 +255,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
return fmt.Errorf("failed to bring up interface: %v", err)
|
||||
}
|
||||
|
||||
if err := ensureMSSClamping(); err != nil {
|
||||
if err := s.ensureMSSClamping(); err != nil {
|
||||
logger.Warn("Failed to ensure MSS clamping: %v", err)
|
||||
}
|
||||
|
||||
@@ -266,7 +264,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func createWireGuardInterface() error {
|
||||
func (s *WireGuardService) createWireGuardInterface() error {
|
||||
wgLink := &netlink.GenericLink{
|
||||
LinkAttrs: netlink.LinkAttrs{Name: interfaceName},
|
||||
LinkType: "wireguard",
|
||||
@@ -274,7 +272,7 @@ func createWireGuardInterface() error {
|
||||
return netlink.LinkAdd(wgLink)
|
||||
}
|
||||
|
||||
func assignIPAddress(ipAddress string) error {
|
||||
func (s *WireGuardService) assignIPAddress(ipAddress string) error {
|
||||
link, err := netlink.LinkByName(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface: %v", err)
|
||||
@@ -288,7 +286,7 @@ func assignIPAddress(ipAddress string) error {
|
||||
return netlink.AddrAdd(link, addr)
|
||||
}
|
||||
|
||||
func ensureWireguardPeers(peers []Peer) error {
|
||||
func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
||||
// get the current peers
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
if err != nil {
|
||||
@@ -311,7 +309,7 @@ func ensureWireguardPeers(peers []Peer) error {
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err := removePeer(peer)
|
||||
err := s.removePeer(peer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove peer: %v", err)
|
||||
}
|
||||
@@ -328,7 +326,7 @@ func ensureWireguardPeers(peers []Peer) error {
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err := addPeer(configPeer)
|
||||
err := s.addPeer(configPeer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add peer: %v", err)
|
||||
}
|
||||
@@ -338,7 +336,7 @@ func ensureWireguardPeers(peers []Peer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureMSSClamping() error {
|
||||
func (s *WireGuardService) ensureMSSClamping() error {
|
||||
// Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
|
||||
mssValue := mtuInt - 40
|
||||
|
||||
@@ -426,7 +424,7 @@ func ensureMSSClamping() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleAddPeer(msg websocket.WSMessage) {
|
||||
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
|
||||
var peer Peer
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
@@ -438,13 +436,13 @@ func handleAddPeer(msg websocket.WSMessage) {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
}
|
||||
|
||||
err = addPeer(peer)
|
||||
err = s.addPeer(peer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func addPeer(peer Peer) error {
|
||||
func (s *WireGuardService) addPeer(peer Peer) error {
|
||||
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
@@ -478,7 +476,7 @@ func addPeer(peer Peer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleRemovePeer(msg websocket.WSMessage) {
|
||||
func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) {
|
||||
// parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" }
|
||||
type RemoveRequest struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
@@ -495,13 +493,13 @@ func handleRemovePeer(msg websocket.WSMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := removePeer(request.PublicKey); err != nil {
|
||||
if err := s.removePeer(request.PublicKey); err != nil {
|
||||
logger.Info("Error removing peer: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func removePeer(publicKey string) error {
|
||||
func (s *WireGuardService) removePeer(publicKey string) error {
|
||||
pubKey, err := wgtypes.ParseKey(publicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
@@ -525,18 +523,18 @@ func removePeer(publicKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func periodicBandwidthCheck(endpoint string) {
|
||||
func (s *WireGuardService) periodicBandwidthCheck() {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if err := reportPeerBandwidth(endpoint); err != nil {
|
||||
if err := s.reportPeerBandwidth(); err != nil {
|
||||
logger.Info("Failed to report peer bandwidth: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get device: %v", err)
|
||||
@@ -621,8 +619,8 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
return peerBandwidths, nil
|
||||
}
|
||||
|
||||
func reportPeerBandwidth(apiURL string) error {
|
||||
bandwidths, err := calculatePeerBandwidth()
|
||||
func (s *WireGuardService) reportPeerBandwidth() error {
|
||||
bandwidths, err := s.calculatePeerBandwidth()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
|
||||
}
|
||||
@@ -632,15 +630,10 @@ func reportPeerBandwidth(apiURL string) error {
|
||||
return fmt.Errorf("failed to marshal bandwidth data: %v", err)
|
||||
}
|
||||
|
||||
resp, err := http.Post(apiURL, "application/json", bytes.NewBuffer(jsonData))
|
||||
err = s.client.SendMessage("wg/bandwidth", jsonData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send bandwidth data: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("API returned non-OK status: %s", resp.Status)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user