mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Clean up implementation
This commit is contained in:
10
main.go
10
main.go
@@ -340,15 +340,7 @@ func main() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to create WireGuard service: %v", err)
|
logger.Fatal("Failed to create WireGuard service: %v", err)
|
||||||
}
|
}
|
||||||
// defer wgService.Close()
|
defer wgService.Close()
|
||||||
|
|
||||||
// Start the WireGuard service
|
|
||||||
if err := wgService.Start(); err != nil {
|
|
||||||
logger.Fatal("Failed to start WireGuard service: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start bandwidth reporting
|
|
||||||
wgService.StartBandwidthReporting()
|
|
||||||
|
|
||||||
// Create TUN device and network stack
|
// Create TUN device and network stack
|
||||||
var tun tun.Device
|
var tun tun.Device
|
||||||
|
|||||||
263
wg/wg.go
263
wg/wg.go
@@ -6,8 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -55,16 +53,15 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type WireGuardService struct {
|
type WireGuardService struct {
|
||||||
interfaceName string
|
interfaceName string
|
||||||
mtu int
|
mtu int
|
||||||
client *websocket.Client
|
client *websocket.Client
|
||||||
wgClient *wgctrl.Client
|
wgClient *wgctrl.Client
|
||||||
config WgConfig
|
config WgConfig
|
||||||
key wgtypes.Key
|
key wgtypes.Key
|
||||||
reachableAt string
|
reachableAt string
|
||||||
generateAndSaveKeyTo string
|
lastReadings map[string]PeerReading
|
||||||
lastReadings map[string]PeerReading
|
mu sync.Mutex
|
||||||
mu sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) {
|
func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) {
|
||||||
@@ -73,90 +70,60 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene
|
|||||||
return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
|
return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key := wgtypes.Key{}
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to generate private key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
service := &WireGuardService{
|
|
||||||
interfaceName: interfaceName,
|
|
||||||
mtu: mtu,
|
|
||||||
client: wsClient,
|
|
||||||
wgClient: wgClient,
|
|
||||||
key: key,
|
|
||||||
reachableAt: reachableAt,
|
|
||||||
generateAndSaveKeyTo: generateAndSaveKeyTo,
|
|
||||||
lastReadings: make(map[string]PeerReading),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register websocket handlers
|
|
||||||
wsClient.RegisterHandler("wg/peer/config", service.handleConfig)
|
|
||||||
wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer)
|
|
||||||
wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer)
|
|
||||||
|
|
||||||
// Register connect handler to initiate configuration
|
|
||||||
wsClient.OnConnect(service.handleConnect)
|
|
||||||
|
|
||||||
return service, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) handleConnect() error {
|
|
||||||
logger.Debug("Public key: %s", s.key.PublicKey())
|
|
||||||
|
|
||||||
err := s.client.SendMessage("wg/register", map[string]interface{}{
|
|
||||||
"publicKey": fmt.Sprintf("%s", s.key.PublicKey()),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to send registration message: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Sent registration message")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) Start() error {
|
|
||||||
|
|
||||||
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
|
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
|
||||||
if _, err := os.Stat(s.generateAndSaveKeyTo); os.IsNotExist(err) {
|
if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
|
||||||
// generate a new private key
|
// generate a new private key
|
||||||
s.key, err = wgtypes.GeneratePrivateKey()
|
key, err = wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to generate private key: %v", err)
|
logger.Fatal("Failed to generate private key: %v", err)
|
||||||
}
|
}
|
||||||
// save the key to the file
|
// save the key to the file
|
||||||
err = os.WriteFile(s.generateAndSaveKeyTo, []byte(s.key.String()), 0644)
|
err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to save private key: %v", err)
|
logger.Fatal("Failed to save private key: %v", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
keyData, err := os.ReadFile(s.generateAndSaveKeyTo)
|
keyData, err := os.ReadFile(generateAndSaveKeyTo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to read private key: %v", err)
|
logger.Fatal("Failed to read private key: %v", err)
|
||||||
}
|
}
|
||||||
s.key, err = wgtypes.ParseKey(string(keyData))
|
key, err = wgtypes.ParseKey(string(keyData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to parse private key: %v", err)
|
logger.Fatal("Failed to parse private key: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get initial configuration
|
service := &WireGuardService{
|
||||||
err := s.loadRemoteConfig()
|
interfaceName: interfaceName,
|
||||||
if err != nil {
|
mtu: mtu,
|
||||||
return fmt.Errorf("failed to load initial configuration: %v", err)
|
client: wsClient,
|
||||||
|
wgClient: wgClient,
|
||||||
|
key: key,
|
||||||
|
reachableAt: reachableAt,
|
||||||
|
lastReadings: make(map[string]PeerReading),
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
// Register websocket handlers
|
||||||
|
wsClient.RegisterHandler("wg/config/receive", service.handleConfig)
|
||||||
|
wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer)
|
||||||
|
wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer)
|
||||||
|
|
||||||
|
// Register connect handler to initiate configuration
|
||||||
|
wsClient.OnConnect(service.loadRemoteConfig)
|
||||||
|
|
||||||
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) StartBandwidthReporting() {
|
func (s *WireGuardService) Close() {
|
||||||
go s.periodicBandwidthCheck()
|
s.client.Close()
|
||||||
|
wgClient.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) loadRemoteConfig() error {
|
func (s *WireGuardService) loadRemoteConfig() error {
|
||||||
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, s.key.PublicKey().String(), s.reachableAt)))
|
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "endpoint": "%s"}`, s.key.PublicKey().String(), s.reachableAt)))
|
||||||
|
|
||||||
// send a ws message to the server to get the config
|
go s.periodicBandwidthCheck()
|
||||||
|
|
||||||
err := s.client.SendMessage("wg/config/get", body)
|
err := s.client.SendMessage("wg/config/get", body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -255,9 +222,9 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
|||||||
return fmt.Errorf("failed to bring up interface: %v", err)
|
return fmt.Errorf("failed to bring up interface: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.ensureMSSClamping(); err != nil {
|
// if err := s.ensureMSSClamping(); err != nil {
|
||||||
logger.Warn("Failed to ensure MSS clamping: %v", err)
|
// logger.Warn("Failed to ensure MSS clamping: %v", err)
|
||||||
}
|
// }
|
||||||
|
|
||||||
logger.Info("WireGuard interface %s created and configured", interfaceName)
|
logger.Info("WireGuard interface %s created and configured", interfaceName)
|
||||||
|
|
||||||
@@ -336,93 +303,93 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) ensureMSSClamping() error {
|
// func (s *WireGuardService) ensureMSSClamping() error {
|
||||||
// Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
|
// // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
|
||||||
mssValue := mtuInt - 40
|
// mssValue := mtuInt - 40
|
||||||
|
|
||||||
// Rules to be managed - just the chains, we'll construct the full command separately
|
// // Rules to be managed - just the chains, we'll construct the full command separately
|
||||||
chains := []string{"INPUT", "OUTPUT", "FORWARD"}
|
// chains := []string{"INPUT", "OUTPUT", "FORWARD"}
|
||||||
|
|
||||||
// First, try to delete any existing rules
|
// // First, try to delete any existing rules
|
||||||
for _, chain := range chains {
|
// for _, chain := range chains {
|
||||||
deleteCmd := exec.Command("/usr/sbin/iptables",
|
// deleteCmd := exec.Command("/usr/sbin/iptables",
|
||||||
"-t", "mangle",
|
// "-t", "mangle",
|
||||||
"-D", chain,
|
// "-D", chain,
|
||||||
"-p", "tcp",
|
// "-p", "tcp",
|
||||||
"--tcp-flags", "SYN,RST", "SYN",
|
// "--tcp-flags", "SYN,RST", "SYN",
|
||||||
"-j", "TCPMSS",
|
// "-j", "TCPMSS",
|
||||||
"--set-mss", fmt.Sprintf("%d", mssValue))
|
// "--set-mss", fmt.Sprintf("%d", mssValue))
|
||||||
|
|
||||||
logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain)
|
// logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain)
|
||||||
|
|
||||||
// Try deletion multiple times to handle multiple existing rules
|
// // Try deletion multiple times to handle multiple existing rules
|
||||||
for i := 0; i < 3; i++ {
|
// for i := 0; i < 3; i++ {
|
||||||
out, err := deleteCmd.CombinedOutput()
|
// out, err := deleteCmd.CombinedOutput()
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
// Convert exit status 1 to string for better logging
|
// // Convert exit status 1 to string for better logging
|
||||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
// if exitErr, ok := err.(*exec.ExitError); ok {
|
||||||
logger.Debug("Deletion stopped for chain %s: %v (output: %s)",
|
// logger.Debug("Deletion stopped for chain %s: %v (output: %s)",
|
||||||
chain, exitErr.String(), string(out))
|
// chain, exitErr.String(), string(out))
|
||||||
}
|
// }
|
||||||
break // No more rules to delete
|
// break // No more rules to delete
|
||||||
}
|
// }
|
||||||
logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1)
|
// logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
// Then add the new rules
|
// // Then add the new rules
|
||||||
var errors []error
|
// var errors []error
|
||||||
for _, chain := range chains {
|
// for _, chain := range chains {
|
||||||
addCmd := exec.Command("/usr/sbin/iptables",
|
// addCmd := exec.Command("/usr/sbin/iptables",
|
||||||
"-t", "mangle",
|
// "-t", "mangle",
|
||||||
"-A", chain,
|
// "-A", chain,
|
||||||
"-p", "tcp",
|
// "-p", "tcp",
|
||||||
"--tcp-flags", "SYN,RST", "SYN",
|
// "--tcp-flags", "SYN,RST", "SYN",
|
||||||
"-j", "TCPMSS",
|
// "-j", "TCPMSS",
|
||||||
"--set-mss", fmt.Sprintf("%d", mssValue))
|
// "--set-mss", fmt.Sprintf("%d", mssValue))
|
||||||
|
|
||||||
logger.Info("Adding MSS clamping rule for chain %s", chain)
|
// logger.Info("Adding MSS clamping rule for chain %s", chain)
|
||||||
|
|
||||||
if out, err := addCmd.CombinedOutput(); err != nil {
|
// if out, err := addCmd.CombinedOutput(); err != nil {
|
||||||
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
|
// errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
|
||||||
chain, err, string(out))
|
// chain, err, string(out))
|
||||||
logger.Error(errMsg)
|
// logger.Error(errMsg)
|
||||||
errors = append(errors, fmt.Errorf(errMsg))
|
// errors = append(errors, fmt.Errorf(errMsg))
|
||||||
continue
|
// continue
|
||||||
}
|
// }
|
||||||
|
|
||||||
// Verify the rule was added
|
// // Verify the rule was added
|
||||||
checkCmd := exec.Command("/usr/sbin/iptables",
|
// checkCmd := exec.Command("/usr/sbin/iptables",
|
||||||
"-t", "mangle",
|
// "-t", "mangle",
|
||||||
"-C", chain,
|
// "-C", chain,
|
||||||
"-p", "tcp",
|
// "-p", "tcp",
|
||||||
"--tcp-flags", "SYN,RST", "SYN",
|
// "--tcp-flags", "SYN,RST", "SYN",
|
||||||
"-j", "TCPMSS",
|
// "-j", "TCPMSS",
|
||||||
"--set-mss", fmt.Sprintf("%d", mssValue))
|
// "--set-mss", fmt.Sprintf("%d", mssValue))
|
||||||
|
|
||||||
if out, err := checkCmd.CombinedOutput(); err != nil {
|
// if out, err := checkCmd.CombinedOutput(); err != nil {
|
||||||
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
|
// errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
|
||||||
chain, err, string(out))
|
// chain, err, string(out))
|
||||||
logger.Error(errMsg)
|
// logger.Error(errMsg)
|
||||||
errors = append(errors, fmt.Errorf(errMsg))
|
// errors = append(errors, fmt.Errorf(errMsg))
|
||||||
continue
|
// continue
|
||||||
}
|
// }
|
||||||
|
|
||||||
logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain)
|
// logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain)
|
||||||
}
|
// }
|
||||||
|
|
||||||
// If we encountered any errors, return them combined
|
// // If we encountered any errors, return them combined
|
||||||
if len(errors) > 0 {
|
// if len(errors) > 0 {
|
||||||
var errMsgs []string
|
// var errMsgs []string
|
||||||
for _, err := range errors {
|
// for _, err := range errors {
|
||||||
errMsgs = append(errMsgs, err.Error())
|
// errMsgs = append(errMsgs, err.Error())
|
||||||
}
|
// }
|
||||||
return fmt.Errorf("MSS clamping setup encountered errors:\n%s",
|
// return fmt.Errorf("MSS clamping setup encountered errors:\n%s",
|
||||||
strings.Join(errMsgs, "\n"))
|
// strings.Join(errMsgs, "\n"))
|
||||||
}
|
// }
|
||||||
|
|
||||||
return nil
|
// return nil
|
||||||
}
|
// }
|
||||||
|
|
||||||
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
|
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
|
||||||
var peer Peer
|
var peer Peer
|
||||||
|
|||||||
Reference in New Issue
Block a user