Clean up implementation

This commit is contained in:
Owen
2025-02-20 21:01:44 -05:00
parent f69a7f647d
commit 66edae4288
2 changed files with 116 additions and 157 deletions

10
main.go
View File

@@ -340,15 +340,7 @@ func main() {
if err != nil { if err != nil {
logger.Fatal("Failed to create WireGuard service: %v", err) logger.Fatal("Failed to create WireGuard service: %v", err)
} }
// defer wgService.Close() defer wgService.Close()
// Start the WireGuard service
if err := wgService.Start(); err != nil {
logger.Fatal("Failed to start WireGuard service: %v", err)
}
// Start bandwidth reporting
wgService.StartBandwidthReporting()
// Create TUN device and network stack // Create TUN device and network stack
var tun tun.Device var tun tun.Device

263
wg/wg.go
View File

@@ -6,8 +6,6 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"os/exec"
"strings"
"sync" "sync"
"time" "time"
@@ -55,16 +53,15 @@ var (
) )
type WireGuardService struct { type WireGuardService struct {
interfaceName string interfaceName string
mtu int mtu int
client *websocket.Client client *websocket.Client
wgClient *wgctrl.Client wgClient *wgctrl.Client
config WgConfig config WgConfig
key wgtypes.Key key wgtypes.Key
reachableAt string reachableAt string
generateAndSaveKeyTo string lastReadings map[string]PeerReading
lastReadings map[string]PeerReading mu sync.Mutex
mu sync.Mutex
} }
func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) { func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) {
@@ -73,90 +70,60 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene
return nil, fmt.Errorf("failed to create WireGuard client: %v", err) return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
} }
key, err := wgtypes.GeneratePrivateKey() key := wgtypes.Key{}
if err != nil {
return nil, fmt.Errorf("failed to generate private key: %v", err)
}
service := &WireGuardService{
interfaceName: interfaceName,
mtu: mtu,
client: wsClient,
wgClient: wgClient,
key: key,
reachableAt: reachableAt,
generateAndSaveKeyTo: generateAndSaveKeyTo,
lastReadings: make(map[string]PeerReading),
}
// Register websocket handlers
wsClient.RegisterHandler("wg/peer/config", service.handleConfig)
wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer)
wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer)
// Register connect handler to initiate configuration
wsClient.OnConnect(service.handleConnect)
return service, nil
}
func (s *WireGuardService) handleConnect() error {
logger.Debug("Public key: %s", s.key.PublicKey())
err := s.client.SendMessage("wg/register", map[string]interface{}{
"publicKey": fmt.Sprintf("%s", s.key.PublicKey()),
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent registration message")
return nil
}
func (s *WireGuardService) Start() error {
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
if _, err := os.Stat(s.generateAndSaveKeyTo); os.IsNotExist(err) { if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
// generate a new private key // generate a new private key
s.key, err = wgtypes.GeneratePrivateKey() key, err = wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
logger.Fatal("Failed to generate private key: %v", err) logger.Fatal("Failed to generate private key: %v", err)
} }
// save the key to the file // save the key to the file
err = os.WriteFile(s.generateAndSaveKeyTo, []byte(s.key.String()), 0644) err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644)
if err != nil { if err != nil {
logger.Fatal("Failed to save private key: %v", err) logger.Fatal("Failed to save private key: %v", err)
} }
} else { } else {
keyData, err := os.ReadFile(s.generateAndSaveKeyTo) keyData, err := os.ReadFile(generateAndSaveKeyTo)
if err != nil { if err != nil {
logger.Fatal("Failed to read private key: %v", err) logger.Fatal("Failed to read private key: %v", err)
} }
s.key, err = wgtypes.ParseKey(string(keyData)) key, err = wgtypes.ParseKey(string(keyData))
if err != nil { if err != nil {
logger.Fatal("Failed to parse private key: %v", err) logger.Fatal("Failed to parse private key: %v", err)
} }
} }
// Get initial configuration service := &WireGuardService{
err := s.loadRemoteConfig() interfaceName: interfaceName,
if err != nil { mtu: mtu,
return fmt.Errorf("failed to load initial configuration: %v", err) client: wsClient,
wgClient: wgClient,
key: key,
reachableAt: reachableAt,
lastReadings: make(map[string]PeerReading),
} }
return nil // Register websocket handlers
wsClient.RegisterHandler("wg/config/receive", service.handleConfig)
wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer)
wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer)
// Register connect handler to initiate configuration
wsClient.OnConnect(service.loadRemoteConfig)
return service, nil
} }
func (s *WireGuardService) StartBandwidthReporting() { func (s *WireGuardService) Close() {
go s.periodicBandwidthCheck() s.client.Close()
wgClient.Close()
} }
func (s *WireGuardService) loadRemoteConfig() error { func (s *WireGuardService) loadRemoteConfig() error {
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, s.key.PublicKey().String(), s.reachableAt))) body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "endpoint": "%s"}`, s.key.PublicKey().String(), s.reachableAt)))
// send a ws message to the server to get the config go s.periodicBandwidthCheck()
err := s.client.SendMessage("wg/config/get", body) err := s.client.SendMessage("wg/config/get", body)
if err != nil { if err != nil {
@@ -255,9 +222,9 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
return fmt.Errorf("failed to bring up interface: %v", err) return fmt.Errorf("failed to bring up interface: %v", err)
} }
if err := s.ensureMSSClamping(); err != nil { // if err := s.ensureMSSClamping(); err != nil {
logger.Warn("Failed to ensure MSS clamping: %v", err) // logger.Warn("Failed to ensure MSS clamping: %v", err)
} // }
logger.Info("WireGuard interface %s created and configured", interfaceName) logger.Info("WireGuard interface %s created and configured", interfaceName)
@@ -336,93 +303,93 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
return nil return nil
} }
func (s *WireGuardService) ensureMSSClamping() error { // func (s *WireGuardService) ensureMSSClamping() error {
// Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20)) // // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
mssValue := mtuInt - 40 // mssValue := mtuInt - 40
// Rules to be managed - just the chains, we'll construct the full command separately // // Rules to be managed - just the chains, we'll construct the full command separately
chains := []string{"INPUT", "OUTPUT", "FORWARD"} // chains := []string{"INPUT", "OUTPUT", "FORWARD"}
// First, try to delete any existing rules // // First, try to delete any existing rules
for _, chain := range chains { // for _, chain := range chains {
deleteCmd := exec.Command("/usr/sbin/iptables", // deleteCmd := exec.Command("/usr/sbin/iptables",
"-t", "mangle", // "-t", "mangle",
"-D", chain, // "-D", chain,
"-p", "tcp", // "-p", "tcp",
"--tcp-flags", "SYN,RST", "SYN", // "--tcp-flags", "SYN,RST", "SYN",
"-j", "TCPMSS", // "-j", "TCPMSS",
"--set-mss", fmt.Sprintf("%d", mssValue)) // "--set-mss", fmt.Sprintf("%d", mssValue))
logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain) // logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain)
// Try deletion multiple times to handle multiple existing rules // // Try deletion multiple times to handle multiple existing rules
for i := 0; i < 3; i++ { // for i := 0; i < 3; i++ {
out, err := deleteCmd.CombinedOutput() // out, err := deleteCmd.CombinedOutput()
if err != nil { // if err != nil {
// Convert exit status 1 to string for better logging // // Convert exit status 1 to string for better logging
if exitErr, ok := err.(*exec.ExitError); ok { // if exitErr, ok := err.(*exec.ExitError); ok {
logger.Debug("Deletion stopped for chain %s: %v (output: %s)", // logger.Debug("Deletion stopped for chain %s: %v (output: %s)",
chain, exitErr.String(), string(out)) // chain, exitErr.String(), string(out))
} // }
break // No more rules to delete // break // No more rules to delete
} // }
logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1) // logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1)
} // }
} // }
// Then add the new rules // // Then add the new rules
var errors []error // var errors []error
for _, chain := range chains { // for _, chain := range chains {
addCmd := exec.Command("/usr/sbin/iptables", // addCmd := exec.Command("/usr/sbin/iptables",
"-t", "mangle", // "-t", "mangle",
"-A", chain, // "-A", chain,
"-p", "tcp", // "-p", "tcp",
"--tcp-flags", "SYN,RST", "SYN", // "--tcp-flags", "SYN,RST", "SYN",
"-j", "TCPMSS", // "-j", "TCPMSS",
"--set-mss", fmt.Sprintf("%d", mssValue)) // "--set-mss", fmt.Sprintf("%d", mssValue))
logger.Info("Adding MSS clamping rule for chain %s", chain) // logger.Info("Adding MSS clamping rule for chain %s", chain)
if out, err := addCmd.CombinedOutput(); err != nil { // if out, err := addCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)", // errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
chain, err, string(out)) // chain, err, string(out))
logger.Error(errMsg) // logger.Error(errMsg)
errors = append(errors, fmt.Errorf(errMsg)) // errors = append(errors, fmt.Errorf(errMsg))
continue // continue
} // }
// Verify the rule was added // // Verify the rule was added
checkCmd := exec.Command("/usr/sbin/iptables", // checkCmd := exec.Command("/usr/sbin/iptables",
"-t", "mangle", // "-t", "mangle",
"-C", chain, // "-C", chain,
"-p", "tcp", // "-p", "tcp",
"--tcp-flags", "SYN,RST", "SYN", // "--tcp-flags", "SYN,RST", "SYN",
"-j", "TCPMSS", // "-j", "TCPMSS",
"--set-mss", fmt.Sprintf("%d", mssValue)) // "--set-mss", fmt.Sprintf("%d", mssValue))
if out, err := checkCmd.CombinedOutput(); err != nil { // if out, err := checkCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)", // errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
chain, err, string(out)) // chain, err, string(out))
logger.Error(errMsg) // logger.Error(errMsg)
errors = append(errors, fmt.Errorf(errMsg)) // errors = append(errors, fmt.Errorf(errMsg))
continue // continue
} // }
logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain) // logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain)
} // }
// If we encountered any errors, return them combined // // If we encountered any errors, return them combined
if len(errors) > 0 { // if len(errors) > 0 {
var errMsgs []string // var errMsgs []string
for _, err := range errors { // for _, err := range errors {
errMsgs = append(errMsgs, err.Error()) // errMsgs = append(errMsgs, err.Error())
} // }
return fmt.Errorf("MSS clamping setup encountered errors:\n%s", // return fmt.Errorf("MSS clamping setup encountered errors:\n%s",
strings.Join(errMsgs, "\n")) // strings.Join(errMsgs, "\n"))
} // }
return nil // return nil
} // }
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) { func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
var peer Peer var peer Peer