mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Clean up implementation
This commit is contained in:
10
main.go
10
main.go
@@ -340,15 +340,7 @@ func main() {
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard service: %v", err)
|
||||
}
|
||||
// defer wgService.Close()
|
||||
|
||||
// Start the WireGuard service
|
||||
if err := wgService.Start(); err != nil {
|
||||
logger.Fatal("Failed to start WireGuard service: %v", err)
|
||||
}
|
||||
|
||||
// Start bandwidth reporting
|
||||
wgService.StartBandwidthReporting()
|
||||
defer wgService.Close()
|
||||
|
||||
// Create TUN device and network stack
|
||||
var tun tun.Device
|
||||
|
||||
245
wg/wg.go
245
wg/wg.go
@@ -6,8 +6,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -62,7 +60,6 @@ type WireGuardService struct {
|
||||
config WgConfig
|
||||
key wgtypes.Key
|
||||
reachableAt string
|
||||
generateAndSaveKeyTo string
|
||||
lastReadings map[string]PeerReading
|
||||
mu sync.Mutex
|
||||
}
|
||||
@@ -73,9 +70,28 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene
|
||||
return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
|
||||
}
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
key := wgtypes.Key{}
|
||||
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
|
||||
if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
|
||||
// generate a new private key
|
||||
key, err = wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate private key: %v", err)
|
||||
logger.Fatal("Failed to generate private key: %v", err)
|
||||
}
|
||||
// save the key to the file
|
||||
err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to save private key: %v", err)
|
||||
}
|
||||
} else {
|
||||
keyData, err := os.ReadFile(generateAndSaveKeyTo)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to read private key: %v", err)
|
||||
}
|
||||
key, err = wgtypes.ParseKey(string(keyData))
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to parse private key: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
service := &WireGuardService{
|
||||
@@ -85,78 +101,29 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene
|
||||
wgClient: wgClient,
|
||||
key: key,
|
||||
reachableAt: reachableAt,
|
||||
generateAndSaveKeyTo: generateAndSaveKeyTo,
|
||||
lastReadings: make(map[string]PeerReading),
|
||||
}
|
||||
|
||||
// Register websocket handlers
|
||||
wsClient.RegisterHandler("wg/peer/config", service.handleConfig)
|
||||
wsClient.RegisterHandler("wg/config/receive", service.handleConfig)
|
||||
wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer)
|
||||
wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer)
|
||||
|
||||
// Register connect handler to initiate configuration
|
||||
wsClient.OnConnect(service.handleConnect)
|
||||
wsClient.OnConnect(service.loadRemoteConfig)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleConnect() error {
|
||||
logger.Debug("Public key: %s", s.key.PublicKey())
|
||||
|
||||
err := s.client.SendMessage("wg/register", map[string]interface{}{
|
||||
"publicKey": fmt.Sprintf("%s", s.key.PublicKey()),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("Sent registration message")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) Start() error {
|
||||
|
||||
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
|
||||
if _, err := os.Stat(s.generateAndSaveKeyTo); os.IsNotExist(err) {
|
||||
// generate a new private key
|
||||
s.key, err = wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to generate private key: %v", err)
|
||||
}
|
||||
// save the key to the file
|
||||
err = os.WriteFile(s.generateAndSaveKeyTo, []byte(s.key.String()), 0644)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to save private key: %v", err)
|
||||
}
|
||||
} else {
|
||||
keyData, err := os.ReadFile(s.generateAndSaveKeyTo)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to read private key: %v", err)
|
||||
}
|
||||
s.key, err = wgtypes.ParseKey(string(keyData))
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to parse private key: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Get initial configuration
|
||||
err := s.loadRemoteConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load initial configuration: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) StartBandwidthReporting() {
|
||||
go s.periodicBandwidthCheck()
|
||||
func (s *WireGuardService) Close() {
|
||||
s.client.Close()
|
||||
wgClient.Close()
|
||||
}
|
||||
|
||||
func (s *WireGuardService) loadRemoteConfig() error {
|
||||
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, s.key.PublicKey().String(), s.reachableAt)))
|
||||
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "endpoint": "%s"}`, s.key.PublicKey().String(), s.reachableAt)))
|
||||
|
||||
// send a ws message to the server to get the config
|
||||
go s.periodicBandwidthCheck()
|
||||
|
||||
err := s.client.SendMessage("wg/config/get", body)
|
||||
if err != nil {
|
||||
@@ -255,9 +222,9 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
return fmt.Errorf("failed to bring up interface: %v", err)
|
||||
}
|
||||
|
||||
if err := s.ensureMSSClamping(); err != nil {
|
||||
logger.Warn("Failed to ensure MSS clamping: %v", err)
|
||||
}
|
||||
// if err := s.ensureMSSClamping(); err != nil {
|
||||
// logger.Warn("Failed to ensure MSS clamping: %v", err)
|
||||
// }
|
||||
|
||||
logger.Info("WireGuard interface %s created and configured", interfaceName)
|
||||
|
||||
@@ -336,93 +303,93 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) ensureMSSClamping() error {
|
||||
// Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
|
||||
mssValue := mtuInt - 40
|
||||
// func (s *WireGuardService) ensureMSSClamping() error {
|
||||
// // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
|
||||
// mssValue := mtuInt - 40
|
||||
|
||||
// Rules to be managed - just the chains, we'll construct the full command separately
|
||||
chains := []string{"INPUT", "OUTPUT", "FORWARD"}
|
||||
// // Rules to be managed - just the chains, we'll construct the full command separately
|
||||
// chains := []string{"INPUT", "OUTPUT", "FORWARD"}
|
||||
|
||||
// First, try to delete any existing rules
|
||||
for _, chain := range chains {
|
||||
deleteCmd := exec.Command("/usr/sbin/iptables",
|
||||
"-t", "mangle",
|
||||
"-D", chain,
|
||||
"-p", "tcp",
|
||||
"--tcp-flags", "SYN,RST", "SYN",
|
||||
"-j", "TCPMSS",
|
||||
"--set-mss", fmt.Sprintf("%d", mssValue))
|
||||
// // First, try to delete any existing rules
|
||||
// for _, chain := range chains {
|
||||
// deleteCmd := exec.Command("/usr/sbin/iptables",
|
||||
// "-t", "mangle",
|
||||
// "-D", chain,
|
||||
// "-p", "tcp",
|
||||
// "--tcp-flags", "SYN,RST", "SYN",
|
||||
// "-j", "TCPMSS",
|
||||
// "--set-mss", fmt.Sprintf("%d", mssValue))
|
||||
|
||||
logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain)
|
||||
// logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain)
|
||||
|
||||
// Try deletion multiple times to handle multiple existing rules
|
||||
for i := 0; i < 3; i++ {
|
||||
out, err := deleteCmd.CombinedOutput()
|
||||
if err != nil {
|
||||
// Convert exit status 1 to string for better logging
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
logger.Debug("Deletion stopped for chain %s: %v (output: %s)",
|
||||
chain, exitErr.String(), string(out))
|
||||
}
|
||||
break // No more rules to delete
|
||||
}
|
||||
logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1)
|
||||
}
|
||||
}
|
||||
// // Try deletion multiple times to handle multiple existing rules
|
||||
// for i := 0; i < 3; i++ {
|
||||
// out, err := deleteCmd.CombinedOutput()
|
||||
// if err != nil {
|
||||
// // Convert exit status 1 to string for better logging
|
||||
// if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
// logger.Debug("Deletion stopped for chain %s: %v (output: %s)",
|
||||
// chain, exitErr.String(), string(out))
|
||||
// }
|
||||
// break // No more rules to delete
|
||||
// }
|
||||
// logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1)
|
||||
// }
|
||||
// }
|
||||
|
||||
// Then add the new rules
|
||||
var errors []error
|
||||
for _, chain := range chains {
|
||||
addCmd := exec.Command("/usr/sbin/iptables",
|
||||
"-t", "mangle",
|
||||
"-A", chain,
|
||||
"-p", "tcp",
|
||||
"--tcp-flags", "SYN,RST", "SYN",
|
||||
"-j", "TCPMSS",
|
||||
"--set-mss", fmt.Sprintf("%d", mssValue))
|
||||
// // Then add the new rules
|
||||
// var errors []error
|
||||
// for _, chain := range chains {
|
||||
// addCmd := exec.Command("/usr/sbin/iptables",
|
||||
// "-t", "mangle",
|
||||
// "-A", chain,
|
||||
// "-p", "tcp",
|
||||
// "--tcp-flags", "SYN,RST", "SYN",
|
||||
// "-j", "TCPMSS",
|
||||
// "--set-mss", fmt.Sprintf("%d", mssValue))
|
||||
|
||||
logger.Info("Adding MSS clamping rule for chain %s", chain)
|
||||
// logger.Info("Adding MSS clamping rule for chain %s", chain)
|
||||
|
||||
if out, err := addCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
|
||||
chain, err, string(out))
|
||||
logger.Error(errMsg)
|
||||
errors = append(errors, fmt.Errorf(errMsg))
|
||||
continue
|
||||
}
|
||||
// if out, err := addCmd.CombinedOutput(); err != nil {
|
||||
// errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
|
||||
// chain, err, string(out))
|
||||
// logger.Error(errMsg)
|
||||
// errors = append(errors, fmt.Errorf(errMsg))
|
||||
// continue
|
||||
// }
|
||||
|
||||
// Verify the rule was added
|
||||
checkCmd := exec.Command("/usr/sbin/iptables",
|
||||
"-t", "mangle",
|
||||
"-C", chain,
|
||||
"-p", "tcp",
|
||||
"--tcp-flags", "SYN,RST", "SYN",
|
||||
"-j", "TCPMSS",
|
||||
"--set-mss", fmt.Sprintf("%d", mssValue))
|
||||
// // Verify the rule was added
|
||||
// checkCmd := exec.Command("/usr/sbin/iptables",
|
||||
// "-t", "mangle",
|
||||
// "-C", chain,
|
||||
// "-p", "tcp",
|
||||
// "--tcp-flags", "SYN,RST", "SYN",
|
||||
// "-j", "TCPMSS",
|
||||
// "--set-mss", fmt.Sprintf("%d", mssValue))
|
||||
|
||||
if out, err := checkCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
|
||||
chain, err, string(out))
|
||||
logger.Error(errMsg)
|
||||
errors = append(errors, fmt.Errorf(errMsg))
|
||||
continue
|
||||
}
|
||||
// if out, err := checkCmd.CombinedOutput(); err != nil {
|
||||
// errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
|
||||
// chain, err, string(out))
|
||||
// logger.Error(errMsg)
|
||||
// errors = append(errors, fmt.Errorf(errMsg))
|
||||
// continue
|
||||
// }
|
||||
|
||||
logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain)
|
||||
}
|
||||
// logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain)
|
||||
// }
|
||||
|
||||
// If we encountered any errors, return them combined
|
||||
if len(errors) > 0 {
|
||||
var errMsgs []string
|
||||
for _, err := range errors {
|
||||
errMsgs = append(errMsgs, err.Error())
|
||||
}
|
||||
return fmt.Errorf("MSS clamping setup encountered errors:\n%s",
|
||||
strings.Join(errMsgs, "\n"))
|
||||
}
|
||||
// // If we encountered any errors, return them combined
|
||||
// if len(errors) > 0 {
|
||||
// var errMsgs []string
|
||||
// for _, err := range errors {
|
||||
// errMsgs = append(errMsgs, err.Error())
|
||||
// }
|
||||
// return fmt.Errorf("MSS clamping setup encountered errors:\n%s",
|
||||
// strings.Join(errMsgs, "\n"))
|
||||
// }
|
||||
|
||||
return nil
|
||||
}
|
||||
// return nil
|
||||
// }
|
||||
|
||||
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
|
||||
var peer Peer
|
||||
|
||||
Reference in New Issue
Block a user