mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Move wg into more of a class
This commit is contained in:
49
main.go
49
main.go
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/proxy"
|
"github.com/fosrl/newt/proxy"
|
||||||
"github.com/fosrl/newt/websocket"
|
"github.com/fosrl/newt/websocket"
|
||||||
|
"github.com/fosrl/newt/wg"
|
||||||
|
|
||||||
"golang.org/x/net/icmp"
|
"golang.org/x/net/icmp"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
@@ -246,15 +247,18 @@ func resolveDomain(domain string) (string, error) {
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var (
|
var (
|
||||||
endpoint string
|
endpoint string
|
||||||
id string
|
id string
|
||||||
secret string
|
secret string
|
||||||
mtu string
|
mtu string
|
||||||
mtuInt int
|
mtuInt int
|
||||||
dns string
|
dns string
|
||||||
privateKey wgtypes.Key
|
privateKey wgtypes.Key
|
||||||
err error
|
err error
|
||||||
logLevel string
|
logLevel string
|
||||||
|
interfaceName string
|
||||||
|
generateAndSaveKeyTo string
|
||||||
|
reachableAt string
|
||||||
)
|
)
|
||||||
|
|
||||||
// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
|
// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
|
||||||
@@ -264,6 +268,9 @@ func main() {
|
|||||||
mtu = os.Getenv("MTU")
|
mtu = os.Getenv("MTU")
|
||||||
dns = os.Getenv("DNS")
|
dns = os.Getenv("DNS")
|
||||||
logLevel = os.Getenv("LOG_LEVEL")
|
logLevel = os.Getenv("LOG_LEVEL")
|
||||||
|
interfaceName = os.Getenv("INTERFACE")
|
||||||
|
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
|
||||||
|
reachableAt = os.Getenv("REACHABLE_AT")
|
||||||
|
|
||||||
if endpoint == "" {
|
if endpoint == "" {
|
||||||
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
||||||
@@ -283,6 +290,15 @@ func main() {
|
|||||||
if logLevel == "" {
|
if logLevel == "" {
|
||||||
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
||||||
}
|
}
|
||||||
|
if interfaceName == "" {
|
||||||
|
flag.StringVar(&interfaceName, "interface", "wg-1", "Name of the WireGuard interface")
|
||||||
|
}
|
||||||
|
if generateAndSaveKeyTo == "" {
|
||||||
|
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
||||||
|
}
|
||||||
|
if reachableAt == "" {
|
||||||
|
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
|
||||||
|
}
|
||||||
|
|
||||||
// do a --version check
|
// do a --version check
|
||||||
version := flag.Bool("version", false, "Print the version")
|
version := flag.Bool("version", false, "Print the version")
|
||||||
@@ -319,6 +335,21 @@ func main() {
|
|||||||
logger.Fatal("Failed to create client: %v", err)
|
logger.Fatal("Failed to create client: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create WireGuard service
|
||||||
|
wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("Failed to create WireGuard service: %v", err)
|
||||||
|
}
|
||||||
|
// defer wgService.Close()
|
||||||
|
|
||||||
|
// Start the WireGuard service
|
||||||
|
if err := wgService.Start(); err != nil {
|
||||||
|
logger.Fatal("Failed to start WireGuard service: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start bandwidth reporting
|
||||||
|
wgService.StartBandwidthReporting()
|
||||||
|
|
||||||
// Create TUN device and network stack
|
// Create TUN device and network stack
|
||||||
var tun tun.Device
|
var tun tun.Device
|
||||||
var tnet *netstack.Net
|
var tnet *netstack.Net
|
||||||
|
|||||||
295
wg/wg.go
295
wg/wg.go
@@ -3,14 +3,10 @@ package wg
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"flag"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -58,147 +54,149 @@ var (
|
|||||||
wgClient *wgctrl.Client
|
wgClient *wgctrl.Client
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
type WireGuardService struct {
|
||||||
var (
|
interfaceName string
|
||||||
err error
|
mtu int
|
||||||
wgconfig WgConfig
|
client *websocket.Client
|
||||||
remoteConfigURL string
|
wgClient *wgctrl.Client
|
||||||
generateAndSaveKeyTo string
|
config WgConfig
|
||||||
reachableAt string
|
key wgtypes.Key
|
||||||
logLevel string
|
reachableAt string
|
||||||
mtu string
|
generateAndSaveKeyTo string
|
||||||
)
|
lastReadings map[string]PeerReading
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
interfaceName = os.Getenv("INTERFACE")
|
func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) {
|
||||||
remoteConfigURL = os.Getenv("REMOTE_CONFIG")
|
wgClient, err := wgctrl.New()
|
||||||
listenAddr = os.Getenv("LISTEN")
|
|
||||||
generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO")
|
|
||||||
reachableAt = os.Getenv("REACHABLE_AT")
|
|
||||||
logLevel = os.Getenv("LOG_LEVEL")
|
|
||||||
mtu = os.Getenv("MTU")
|
|
||||||
|
|
||||||
if interfaceName == "" {
|
|
||||||
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
|
|
||||||
}
|
|
||||||
if remoteConfigURL == "" {
|
|
||||||
flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL to fetch remote configuration")
|
|
||||||
}
|
|
||||||
if generateAndSaveKeyTo == "" {
|
|
||||||
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
|
||||||
}
|
|
||||||
if reachableAt == "" {
|
|
||||||
flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about")
|
|
||||||
}
|
|
||||||
if logLevel == "" {
|
|
||||||
flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)")
|
|
||||||
}
|
|
||||||
if mtu == "" {
|
|
||||||
flag.StringVar(&mtu, "mtu", "1280", "MTU of the WireGuard interface")
|
|
||||||
}
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
mtuInt, err = strconv.Atoi(mtu)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to parse MTU: %v", err)
|
return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var key wgtypes.Key
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate private key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
service := &WireGuardService{
|
||||||
|
interfaceName: interfaceName,
|
||||||
|
mtu: mtu,
|
||||||
|
client: wsClient,
|
||||||
|
wgClient: wgClient,
|
||||||
|
key: key,
|
||||||
|
reachableAt: reachableAt,
|
||||||
|
generateAndSaveKeyTo: generateAndSaveKeyTo,
|
||||||
|
lastReadings: make(map[string]PeerReading),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register websocket handlers
|
||||||
|
wsClient.RegisterHandler("wg/peer/config", service.handleConfig)
|
||||||
|
wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer)
|
||||||
|
wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer)
|
||||||
|
|
||||||
|
// Register connect handler to initiate configuration
|
||||||
|
wsClient.OnConnect(service.handleConnect)
|
||||||
|
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WireGuardService) handleConnect() error {
|
||||||
|
logger.Debug("Public key: %s", s.key.PublicKey())
|
||||||
|
|
||||||
|
err := s.client.SendMessage("wg/register", map[string]interface{}{
|
||||||
|
"publicKey": fmt.Sprintf("%s", s.key.PublicKey()),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to send registration message: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Sent registration message")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WireGuardService) Start() error {
|
||||||
|
|
||||||
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
|
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
|
||||||
if generateAndSaveKeyTo != "" {
|
if _, err := os.Stat(s.generateAndSaveKeyTo); os.IsNotExist(err) {
|
||||||
if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
|
// generate a new private key
|
||||||
// generate a new private key
|
s.key, err = wgtypes.GeneratePrivateKey()
|
||||||
key, err = wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to generate private key: %v", err)
|
|
||||||
}
|
|
||||||
// save the key to the file
|
|
||||||
err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to save private key: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
keyData, err := os.ReadFile(generateAndSaveKeyTo)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to read private key: %v", err)
|
|
||||||
}
|
|
||||||
key, err = wgtypes.ParseKey(string(keyData))
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to parse private key: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// if no generateAndSaveKeyTo is provided, ensure that the private key is provided
|
|
||||||
if wgconfig.PrivateKey == "" {
|
|
||||||
// generate a new one
|
|
||||||
key, err = wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to generate private key: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// loop until we get the config
|
|
||||||
for wgconfig.PrivateKey == "" {
|
|
||||||
logger.Info("Fetching remote config from %s", remoteConfigURL)
|
|
||||||
wgconfig, err = loadRemoteConfig(remoteConfigURL, key, reachableAt)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to load configuration: %v", err)
|
logger.Fatal("Failed to generate private key: %v", err)
|
||||||
time.Sleep(5 * time.Second)
|
}
|
||||||
continue
|
// save the key to the file
|
||||||
|
err = os.WriteFile(s.generateAndSaveKeyTo, []byte(s.key.String()), 0644)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("Failed to save private key: %v", err)
|
||||||
}
|
}
|
||||||
wgconfig.PrivateKey = key.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
wgClient, err = wgctrl.New()
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to create WireGuard client: %v", err)
|
|
||||||
}
|
|
||||||
defer wgClient.Close()
|
|
||||||
|
|
||||||
// Ensure the WireGuard interface exists and is configured
|
|
||||||
if err := ensureWireguardInterface(wgconfig); err != nil {
|
|
||||||
logger.Fatal("Failed to ensure WireGuard interface: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the WireGuard peers exist
|
|
||||||
ensureWireguardPeers(wgconfig.Peers)
|
|
||||||
|
|
||||||
// go periodicBandwidthCheck(reportBandwidthTo)
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
|
|
||||||
var body *bytes.Buffer
|
|
||||||
if reachableAt == "" {
|
|
||||||
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, key.PublicKey().String())))
|
|
||||||
} else {
|
} else {
|
||||||
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, key.PublicKey().String(), reachableAt)))
|
keyData, err := os.ReadFile(s.generateAndSaveKeyTo)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("Failed to read private key: %v", err)
|
||||||
|
}
|
||||||
|
s.key, err = wgtypes.ParseKey(string(keyData))
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("Failed to parse private key: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
resp, err := http.Post(url, "application/json", body)
|
|
||||||
|
// Get initial configuration
|
||||||
|
err := s.loadRemoteConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// print the error
|
return fmt.Errorf("failed to load initial configuration: %v", err)
|
||||||
logger.Error("Error fetching remote config %s: %v", url, err)
|
|
||||||
return WgConfig{}, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
data, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return WgConfig{}, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var config WgConfig
|
return nil
|
||||||
err = json.Unmarshal(data, &config)
|
|
||||||
|
|
||||||
return config, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ensureWireguardInterface(wgconfig WgConfig) error {
|
func (s *WireGuardService) StartBandwidthReporting() {
|
||||||
|
go s.periodicBandwidthCheck()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WireGuardService) loadRemoteConfig() error {
|
||||||
|
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, s.key.PublicKey().String(), s.reachableAt)))
|
||||||
|
|
||||||
|
// send a ws message to the server to get the config
|
||||||
|
|
||||||
|
err := s.client.SendMessage("wg/config/get", body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to send config request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
||||||
|
var config WgConfig
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Info("Error marshaling data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &config); err != nil {
|
||||||
|
logger.Info("Error unmarshaling target data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.config = config
|
||||||
|
|
||||||
|
// Ensure the WireGuard interface and peers are configured
|
||||||
|
if err := s.ensureWireguardInterface(config); err != nil {
|
||||||
|
logger.Error("Failed to ensure WireGuard interface: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.ensureWireguardPeers(config.Peers); err != nil {
|
||||||
|
logger.Error("Failed to ensure WireGuard peers: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||||
// Check if the WireGuard interface exists
|
// Check if the WireGuard interface exists
|
||||||
_, err := netlink.LinkByName(interfaceName)
|
_, err := netlink.LinkByName(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(netlink.LinkNotFoundError); ok {
|
if _, ok := err.(netlink.LinkNotFoundError); ok {
|
||||||
// Interface doesn't exist, so create it
|
// Interface doesn't exist, so create it
|
||||||
err = createWireGuardInterface()
|
err = s.createWireGuardInterface()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to create WireGuard interface: %v", err)
|
logger.Fatal("Failed to create WireGuard interface: %v", err)
|
||||||
}
|
}
|
||||||
@@ -212,7 +210,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Assign IP address to the interface
|
// Assign IP address to the interface
|
||||||
err = assignIPAddress(wgconfig.IpAddress)
|
err = s.assignIPAddress(wgconfig.IpAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to assign IP address: %v", err)
|
logger.Fatal("Failed to assign IP address: %v", err)
|
||||||
}
|
}
|
||||||
@@ -257,7 +255,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
|
|||||||
return fmt.Errorf("failed to bring up interface: %v", err)
|
return fmt.Errorf("failed to bring up interface: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ensureMSSClamping(); err != nil {
|
if err := s.ensureMSSClamping(); err != nil {
|
||||||
logger.Warn("Failed to ensure MSS clamping: %v", err)
|
logger.Warn("Failed to ensure MSS clamping: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -266,7 +264,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createWireGuardInterface() error {
|
func (s *WireGuardService) createWireGuardInterface() error {
|
||||||
wgLink := &netlink.GenericLink{
|
wgLink := &netlink.GenericLink{
|
||||||
LinkAttrs: netlink.LinkAttrs{Name: interfaceName},
|
LinkAttrs: netlink.LinkAttrs{Name: interfaceName},
|
||||||
LinkType: "wireguard",
|
LinkType: "wireguard",
|
||||||
@@ -274,7 +272,7 @@ func createWireGuardInterface() error {
|
|||||||
return netlink.LinkAdd(wgLink)
|
return netlink.LinkAdd(wgLink)
|
||||||
}
|
}
|
||||||
|
|
||||||
func assignIPAddress(ipAddress string) error {
|
func (s *WireGuardService) assignIPAddress(ipAddress string) error {
|
||||||
link, err := netlink.LinkByName(interfaceName)
|
link, err := netlink.LinkByName(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get interface: %v", err)
|
return fmt.Errorf("failed to get interface: %v", err)
|
||||||
@@ -288,7 +286,7 @@ func assignIPAddress(ipAddress string) error {
|
|||||||
return netlink.AddrAdd(link, addr)
|
return netlink.AddrAdd(link, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ensureWireguardPeers(peers []Peer) error {
|
func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
||||||
// get the current peers
|
// get the current peers
|
||||||
device, err := wgClient.Device(interfaceName)
|
device, err := wgClient.Device(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -311,7 +309,7 @@ func ensureWireguardPeers(peers []Peer) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
err := removePeer(peer)
|
err := s.removePeer(peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to remove peer: %v", err)
|
return fmt.Errorf("failed to remove peer: %v", err)
|
||||||
}
|
}
|
||||||
@@ -328,7 +326,7 @@ func ensureWireguardPeers(peers []Peer) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
err := addPeer(configPeer)
|
err := s.addPeer(configPeer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to add peer: %v", err)
|
return fmt.Errorf("failed to add peer: %v", err)
|
||||||
}
|
}
|
||||||
@@ -338,7 +336,7 @@ func ensureWireguardPeers(peers []Peer) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ensureMSSClamping() error {
|
func (s *WireGuardService) ensureMSSClamping() error {
|
||||||
// Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
|
// Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
|
||||||
mssValue := mtuInt - 40
|
mssValue := mtuInt - 40
|
||||||
|
|
||||||
@@ -426,7 +424,7 @@ func ensureMSSClamping() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleAddPeer(msg websocket.WSMessage) {
|
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
|
||||||
var peer Peer
|
var peer Peer
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
@@ -438,13 +436,13 @@ func handleAddPeer(msg websocket.WSMessage) {
|
|||||||
logger.Info("Error unmarshaling target data: %v", err)
|
logger.Info("Error unmarshaling target data: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = addPeer(peer)
|
err = s.addPeer(peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func addPeer(peer Peer) error {
|
func (s *WireGuardService) addPeer(peer Peer) error {
|
||||||
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
|
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to parse public key: %v", err)
|
return fmt.Errorf("failed to parse public key: %v", err)
|
||||||
@@ -478,7 +476,7 @@ func addPeer(peer Peer) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleRemovePeer(msg websocket.WSMessage) {
|
func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) {
|
||||||
// parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" }
|
// parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" }
|
||||||
type RemoveRequest struct {
|
type RemoveRequest struct {
|
||||||
PublicKey string `json:"publicKey"`
|
PublicKey string `json:"publicKey"`
|
||||||
@@ -495,13 +493,13 @@ func handleRemovePeer(msg websocket.WSMessage) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := removePeer(request.PublicKey); err != nil {
|
if err := s.removePeer(request.PublicKey); err != nil {
|
||||||
logger.Info("Error removing peer: %v", err)
|
logger.Info("Error removing peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func removePeer(publicKey string) error {
|
func (s *WireGuardService) removePeer(publicKey string) error {
|
||||||
pubKey, err := wgtypes.ParseKey(publicKey)
|
pubKey, err := wgtypes.ParseKey(publicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to parse public key: %v", err)
|
return fmt.Errorf("failed to parse public key: %v", err)
|
||||||
@@ -525,18 +523,18 @@ func removePeer(publicKey string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func periodicBandwidthCheck(endpoint string) {
|
func (s *WireGuardService) periodicBandwidthCheck() {
|
||||||
ticker := time.NewTicker(10 * time.Second)
|
ticker := time.NewTicker(10 * time.Second)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
if err := reportPeerBandwidth(endpoint); err != nil {
|
if err := s.reportPeerBandwidth(); err != nil {
|
||||||
logger.Info("Failed to report peer bandwidth: %v", err)
|
logger.Info("Failed to report peer bandwidth: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||||
device, err := wgClient.Device(interfaceName)
|
device, err := wgClient.Device(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get device: %v", err)
|
return nil, fmt.Errorf("failed to get device: %v", err)
|
||||||
@@ -621,8 +619,8 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
|||||||
return peerBandwidths, nil
|
return peerBandwidths, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func reportPeerBandwidth(apiURL string) error {
|
func (s *WireGuardService) reportPeerBandwidth() error {
|
||||||
bandwidths, err := calculatePeerBandwidth()
|
bandwidths, err := s.calculatePeerBandwidth()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
|
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
|
||||||
}
|
}
|
||||||
@@ -632,15 +630,10 @@ func reportPeerBandwidth(apiURL string) error {
|
|||||||
return fmt.Errorf("failed to marshal bandwidth data: %v", err)
|
return fmt.Errorf("failed to marshal bandwidth data: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := http.Post(apiURL, "application/json", bytes.NewBuffer(jsonData))
|
err = s.client.SendMessage("wg/bandwidth", jsonData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to send bandwidth data: %v", err)
|
return fmt.Errorf("failed to send bandwidth data: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return fmt.Errorf("API returned non-OK status: %s", resp.Status)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user