diff --git a/go.mod b/go.mod index 2cc0c19..7812b1b 100644 --- a/go.mod +++ b/go.mod @@ -8,13 +8,21 @@ require golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 require ( github.com/google/btree v1.1.2 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect - golang.org/x/crypto v0.28.0 // indirect + github.com/josharian/native v1.1.0 // indirect + github.com/mdlayher/genetlink v1.3.2 // indirect + github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mdlayher/socket v0.5.1 // indirect + github.com/vishvananda/netlink v1.3.0 // indirect + github.com/vishvananda/netns v0.0.4 // indirect + golang.org/x/crypto v0.31.0 // indirect golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect - golang.org/x/net v0.30.0 // indirect - golang.org/x/sys v0.26.0 // indirect + golang.org/x/net v0.33.0 // indirect + golang.org/x/sync v0.10.0 // indirect + golang.org/x/sys v0.28.0 // indirect golang.org/x/time v0.7.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect + golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 // indirect gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect ) diff --git a/go.sum b/go.sum index d95ab3a..f453d4f 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,39 @@ github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= +github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= +github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= +github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= +github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= +github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= +github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= +github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= +github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA= golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= @@ -18,5 +42,7 @@ golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uI golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= diff --git a/wg/wg.go b/wg/wg.go new file mode 100644 index 0000000..ed868ea --- /dev/null +++ b/wg/wg.go @@ -0,0 +1,646 @@ +package wg + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "io" + "net" + "net/http" + "os" + "os/exec" + "strconv" + "strings" + "sync" + "time" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/websocket" + "github.com/vishvananda/netlink" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +var ( + interfaceName string + listenAddr string + mtuInt int + lastReadings = make(map[string]PeerReading) + mu sync.Mutex +) + +type WgConfig struct { + PrivateKey string `json:"privateKey"` + ListenPort int `json:"listenPort"` + IpAddress string `json:"ipAddress"` + Peers []Peer `json:"peers"` +} + +type Peer struct { + PublicKey string `json:"publicKey"` + AllowedIPs []string `json:"allowedIps"` +} + +type PeerBandwidth struct { + PublicKey string `json:"publicKey"` + BytesIn float64 `json:"bytesIn"` + BytesOut float64 `json:"bytesOut"` +} + +type PeerReading struct { + BytesReceived int64 + BytesTransmitted int64 + LastChecked time.Time +} + +var ( + wgClient *wgctrl.Client +) + +func main() { + var ( + err error + wgconfig WgConfig + remoteConfigURL string + generateAndSaveKeyTo string + reachableAt string + logLevel string + mtu string + ) + + interfaceName = os.Getenv("INTERFACE") + remoteConfigURL = os.Getenv("REMOTE_CONFIG") + listenAddr = os.Getenv("LISTEN") + generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") + reachableAt = os.Getenv("REACHABLE_AT") + logLevel = os.Getenv("LOG_LEVEL") + mtu = os.Getenv("MTU") + + if interfaceName == "" { + flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") + } + if remoteConfigURL == "" { + flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL to fetch remote configuration") + } + if generateAndSaveKeyTo == "" { + flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") + } + if reachableAt == "" { + flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about") + } + if logLevel == "" { + flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") + } + if mtu == "" { + flag.StringVar(&mtu, "mtu", "1280", "MTU of the WireGuard interface") + } + flag.Parse() + + mtuInt, err = strconv.Atoi(mtu) + if err != nil { + logger.Fatal("Failed to parse MTU: %v", err) + } + + var key wgtypes.Key + // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file + if generateAndSaveKeyTo != "" { + if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { + // generate a new private key + key, err = wgtypes.GeneratePrivateKey() + if err != nil { + logger.Fatal("Failed to generate private key: %v", err) + } + // save the key to the file + err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644) + if err != nil { + logger.Fatal("Failed to save private key: %v", err) + } + } else { + keyData, err := os.ReadFile(generateAndSaveKeyTo) + if err != nil { + logger.Fatal("Failed to read private key: %v", err) + } + key, err = wgtypes.ParseKey(string(keyData)) + if err != nil { + logger.Fatal("Failed to parse private key: %v", err) + } + } + } else { + // if no generateAndSaveKeyTo is provided, ensure that the private key is provided + if wgconfig.PrivateKey == "" { + // generate a new one + key, err = wgtypes.GeneratePrivateKey() + if err != nil { + logger.Fatal("Failed to generate private key: %v", err) + } + } + } + + // loop until we get the config + for wgconfig.PrivateKey == "" { + logger.Info("Fetching remote config from %s", remoteConfigURL) + wgconfig, err = loadRemoteConfig(remoteConfigURL, key, reachableAt) + if err != nil { + logger.Error("Failed to load configuration: %v", err) + time.Sleep(5 * time.Second) + continue + } + wgconfig.PrivateKey = key.String() + } + + wgClient, err = wgctrl.New() + if err != nil { + logger.Fatal("Failed to create WireGuard client: %v", err) + } + defer wgClient.Close() + + // Ensure the WireGuard interface exists and is configured + if err := ensureWireguardInterface(wgconfig); err != nil { + logger.Fatal("Failed to ensure WireGuard interface: %v", err) + } + + // Ensure the WireGuard peers exist + ensureWireguardPeers(wgconfig.Peers) + + // go periodicBandwidthCheck(reportBandwidthTo) +} + +func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) { + var body *bytes.Buffer + if reachableAt == "" { + body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, key.PublicKey().String()))) + } else { + body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, key.PublicKey().String(), reachableAt))) + } + resp, err := http.Post(url, "application/json", body) + if err != nil { + // print the error + logger.Error("Error fetching remote config %s: %v", url, err) + return WgConfig{}, err + } + defer resp.Body.Close() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return WgConfig{}, err + } + + var config WgConfig + err = json.Unmarshal(data, &config) + + return config, err +} + +func ensureWireguardInterface(wgconfig WgConfig) error { + // Check if the WireGuard interface exists + _, err := netlink.LinkByName(interfaceName) + if err != nil { + if _, ok := err.(netlink.LinkNotFoundError); ok { + // Interface doesn't exist, so create it + err = createWireGuardInterface() + if err != nil { + logger.Fatal("Failed to create WireGuard interface: %v", err) + } + logger.Info("Created WireGuard interface %s\n", interfaceName) + } else { + logger.Fatal("Error checking for WireGuard interface: %v", err) + } + } else { + logger.Info("WireGuard interface %s already exists\n", interfaceName) + return nil + } + + // Assign IP address to the interface + err = assignIPAddress(wgconfig.IpAddress) + if err != nil { + logger.Fatal("Failed to assign IP address: %v", err) + } + logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, interfaceName) + + // Check if the interface already exists + _, err = wgClient.Device(interfaceName) + if err != nil { + return fmt.Errorf("interface %s does not exist", interfaceName) + } + + // Parse the private key + key, err := wgtypes.ParseKey(wgconfig.PrivateKey) + if err != nil { + return fmt.Errorf("failed to parse private key: %v", err) + } + + // Create a new WireGuard configuration + config := wgtypes.Config{ + PrivateKey: &key, + ListenPort: new(int), + } + *config.ListenPort = wgconfig.ListenPort + + // Create and configure the WireGuard interface + err = wgClient.ConfigureDevice(interfaceName, config) + if err != nil { + return fmt.Errorf("failed to configure WireGuard device: %v", err) + } + + // bring up the interface + link, err := netlink.LinkByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface: %v", err) + } + + if err := netlink.LinkSetMTU(link, mtuInt); err != nil { + return fmt.Errorf("failed to set MTU: %v", err) + } + + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to bring up interface: %v", err) + } + + if err := ensureMSSClamping(); err != nil { + logger.Warn("Failed to ensure MSS clamping: %v", err) + } + + logger.Info("WireGuard interface %s created and configured", interfaceName) + + return nil +} + +func createWireGuardInterface() error { + wgLink := &netlink.GenericLink{ + LinkAttrs: netlink.LinkAttrs{Name: interfaceName}, + LinkType: "wireguard", + } + return netlink.LinkAdd(wgLink) +} + +func assignIPAddress(ipAddress string) error { + link, err := netlink.LinkByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface: %v", err) + } + + addr, err := netlink.ParseAddr(ipAddress) + if err != nil { + return fmt.Errorf("failed to parse IP address: %v", err) + } + + return netlink.AddrAdd(link, addr) +} + +func ensureWireguardPeers(peers []Peer) error { + // get the current peers + device, err := wgClient.Device(interfaceName) + if err != nil { + return fmt.Errorf("failed to get device: %v", err) + } + + // get the peer public keys + var currentPeers []string + for _, peer := range device.Peers { + currentPeers = append(currentPeers, peer.PublicKey.String()) + } + + // remove any peers that are not in the config + for _, peer := range currentPeers { + found := false + for _, configPeer := range peers { + if peer == configPeer.PublicKey { + found = true + break + } + } + if !found { + err := removePeer(peer) + if err != nil { + return fmt.Errorf("failed to remove peer: %v", err) + } + } + } + + // add any peers that are in the config but not in the current peers + for _, configPeer := range peers { + found := false + for _, peer := range currentPeers { + if configPeer.PublicKey == peer { + found = true + break + } + } + if !found { + err := addPeer(configPeer) + if err != nil { + return fmt.Errorf("failed to add peer: %v", err) + } + } + } + + return nil +} + +func ensureMSSClamping() error { + // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20)) + mssValue := mtuInt - 40 + + // Rules to be managed - just the chains, we'll construct the full command separately + chains := []string{"INPUT", "OUTPUT", "FORWARD"} + + // First, try to delete any existing rules + for _, chain := range chains { + deleteCmd := exec.Command("/usr/sbin/iptables", + "-t", "mangle", + "-D", chain, + "-p", "tcp", + "--tcp-flags", "SYN,RST", "SYN", + "-j", "TCPMSS", + "--set-mss", fmt.Sprintf("%d", mssValue)) + + logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain) + + // Try deletion multiple times to handle multiple existing rules + for i := 0; i < 3; i++ { + out, err := deleteCmd.CombinedOutput() + if err != nil { + // Convert exit status 1 to string for better logging + if exitErr, ok := err.(*exec.ExitError); ok { + logger.Debug("Deletion stopped for chain %s: %v (output: %s)", + chain, exitErr.String(), string(out)) + } + break // No more rules to delete + } + logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1) + } + } + + // Then add the new rules + var errors []error + for _, chain := range chains { + addCmd := exec.Command("/usr/sbin/iptables", + "-t", "mangle", + "-A", chain, + "-p", "tcp", + "--tcp-flags", "SYN,RST", "SYN", + "-j", "TCPMSS", + "--set-mss", fmt.Sprintf("%d", mssValue)) + + logger.Info("Adding MSS clamping rule for chain %s", chain) + + if out, err := addCmd.CombinedOutput(); err != nil { + errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)", + chain, err, string(out)) + logger.Error(errMsg) + errors = append(errors, fmt.Errorf(errMsg)) + continue + } + + // Verify the rule was added + checkCmd := exec.Command("/usr/sbin/iptables", + "-t", "mangle", + "-C", chain, + "-p", "tcp", + "--tcp-flags", "SYN,RST", "SYN", + "-j", "TCPMSS", + "--set-mss", fmt.Sprintf("%d", mssValue)) + + if out, err := checkCmd.CombinedOutput(); err != nil { + errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)", + chain, err, string(out)) + logger.Error(errMsg) + errors = append(errors, fmt.Errorf(errMsg)) + continue + } + + logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain) + } + + // If we encountered any errors, return them combined + if len(errors) > 0 { + var errMsgs []string + for _, err := range errors { + errMsgs = append(errMsgs, err.Error()) + } + return fmt.Errorf("MSS clamping setup encountered errors:\n%s", + strings.Join(errMsgs, "\n")) + } + + return nil +} + +func handleAddPeer(msg websocket.WSMessage) { + var peer Peer + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + } + + if err := json.Unmarshal(jsonData, &peer); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + } + + err = addPeer(peer) + if err != nil { + return + } +} + +func addPeer(peer Peer) error { + pubKey, err := wgtypes.ParseKey(peer.PublicKey) + if err != nil { + return fmt.Errorf("failed to parse public key: %v", err) + } + + // parse allowed IPs into array of net.IPNet + var allowedIPs []net.IPNet + for _, ipStr := range peer.AllowedIPs { + _, ipNet, err := net.ParseCIDR(ipStr) + if err != nil { + return fmt.Errorf("failed to parse allowed IP: %v", err) + } + allowedIPs = append(allowedIPs, *ipNet) + } + + peerConfig := wgtypes.PeerConfig{ + PublicKey: pubKey, + AllowedIPs: allowedIPs, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peerConfig}, + } + + if err := wgClient.ConfigureDevice(interfaceName, config); err != nil { + return fmt.Errorf("failed to add peer: %v", err) + } + + logger.Info("Peer %s added successfully", peer.PublicKey) + + return nil +} + +func handleRemovePeer(msg websocket.WSMessage) { + // parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" } + type RemoveRequest struct { + PublicKey string `json:"publicKey"` + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + } + + var request RemoveRequest + if err := json.Unmarshal(jsonData, &request); err != nil { + logger.Info("Error unmarshaling data: %v", err) + return + } + + if err := removePeer(request.PublicKey); err != nil { + logger.Info("Error removing peer: %v", err) + return + } +} + +func removePeer(publicKey string) error { + pubKey, err := wgtypes.ParseKey(publicKey) + if err != nil { + return fmt.Errorf("failed to parse public key: %v", err) + } + + peerConfig := wgtypes.PeerConfig{ + PublicKey: pubKey, + Remove: true, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peerConfig}, + } + + if err := wgClient.ConfigureDevice(interfaceName, config); err != nil { + return fmt.Errorf("failed to remove peer: %v", err) + } + + logger.Info("Peer %s removed successfully", publicKey) + + return nil +} + +func periodicBandwidthCheck(endpoint string) { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for range ticker.C { + if err := reportPeerBandwidth(endpoint); err != nil { + logger.Info("Failed to report peer bandwidth: %v", err) + } + } +} + +func calculatePeerBandwidth() ([]PeerBandwidth, error) { + device, err := wgClient.Device(interfaceName) + if err != nil { + return nil, fmt.Errorf("failed to get device: %v", err) + } + + peerBandwidths := []PeerBandwidth{} + now := time.Now() + + mu.Lock() + defer mu.Unlock() + + for _, peer := range device.Peers { + publicKey := peer.PublicKey.String() + currentReading := PeerReading{ + BytesReceived: peer.ReceiveBytes, + BytesTransmitted: peer.TransmitBytes, + LastChecked: now, + } + + var bytesInDiff, bytesOutDiff float64 + lastReading, exists := lastReadings[publicKey] + + if exists { + timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds() + if timeDiff > 0 { + // Calculate bytes transferred since last reading + bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived) + bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted) + + // Handle counter wraparound (if the counter resets or overflows) + if bytesInDiff < 0 { + bytesInDiff = float64(currentReading.BytesReceived) + } + if bytesOutDiff < 0 { + bytesOutDiff = float64(currentReading.BytesTransmitted) + } + + // Convert to MB + bytesInMB := bytesInDiff / (1024 * 1024) + bytesOutMB := bytesOutDiff / (1024 * 1024) + + peerBandwidths = append(peerBandwidths, PeerBandwidth{ + PublicKey: publicKey, + BytesIn: bytesInMB, + BytesOut: bytesOutMB, + }) + } else { + // If readings are too close together or time hasn't passed, report 0 + peerBandwidths = append(peerBandwidths, PeerBandwidth{ + PublicKey: publicKey, + BytesIn: 0, + BytesOut: 0, + }) + } + } else { + // For first reading of a peer, report 0 to establish baseline + peerBandwidths = append(peerBandwidths, PeerBandwidth{ + PublicKey: publicKey, + BytesIn: 0, + BytesOut: 0, + }) + } + + // Update the last reading + lastReadings[publicKey] = currentReading + } + + // Clean up old peers + for publicKey := range lastReadings { + found := false + for _, peer := range device.Peers { + if peer.PublicKey.String() == publicKey { + found = true + break + } + } + if !found { + delete(lastReadings, publicKey) + } + } + + return peerBandwidths, nil +} + +func reportPeerBandwidth(apiURL string) error { + bandwidths, err := calculatePeerBandwidth() + if err != nil { + return fmt.Errorf("failed to calculate peer bandwidth: %v", err) + } + + jsonData, err := json.Marshal(bandwidths) + if err != nil { + return fmt.Errorf("failed to marshal bandwidth data: %v", err) + } + + resp, err := http.Post(apiURL, "application/json", bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to send bandwidth data: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("API returned non-OK status: %s", resp.Status) + } + + return nil +}