From b9261b8fea620d5254ecffee996d0ef88d4c9b1b Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 27 Feb 2026 15:45:17 -0800 Subject: [PATCH] Add optional tc --- main.go | 228 ++++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 207 insertions(+), 21 deletions(-) diff --git a/main.go b/main.go index b723530..dde0026 100644 --- a/main.go +++ b/main.go @@ -33,15 +33,16 @@ import ( ) var ( - interfaceName string - listenAddr string - mtuInt int - lastReadings = make(map[string]PeerReading) - mu sync.Mutex - wgMu sync.Mutex // Protects WireGuard operations - notifyURL string - proxyRelay *relay.UDPProxyServer - proxySNI *proxy.SNIProxy + interfaceName string + listenAddr string + mtuInt int + lastReadings = make(map[string]PeerReading) + mu sync.Mutex + wgMu sync.Mutex // Protects WireGuard operations + notifyURL string + proxyRelay *relay.UDPProxyServer + proxySNI *proxy.SNIProxy + doTrafficShaping bool ) type WgConfig struct { @@ -151,6 +152,7 @@ func main() { localOverridesStr = os.Getenv("LOCAL_OVERRIDES") trustedUpstreamsStr = os.Getenv("TRUSTED_UPSTREAMS") proxyProtocolStr := os.Getenv("PROXY_PROTOCOL") + doTrafficShapingStr := os.Getenv("DO_TRAFFIC_SHAPING") if interfaceName == "" { flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") @@ -222,6 +224,13 @@ func main() { flag.BoolVar(&proxyProtocol, "proxy-protocol", true, "Enable PROXY protocol v1 for preserving client IP") } + if doTrafficShapingStr != "" { + doTrafficShaping = strings.ToLower(doTrafficShapingStr) == "true" + } + if doTrafficShapingStr == "" { + flag.BoolVar(&doTrafficShaping, "do-traffic-shaping", false, "Whether to set up traffic shaping rules for peers (requires tc command and root privileges)") + } + flag.Parse() logger.Init() @@ -886,17 +895,23 @@ func addPeerInternal(peer Peer) error { return fmt.Errorf("failed to parse public key: %v", err) } + logger.Debug("Adding peer %s with AllowedIPs: %v", peer.PublicKey, peer.AllowedIPs) + // parse allowed IPs into array of net.IPNet var allowedIPs []net.IPNet var wgIPs []string for _, ipStr := range peer.AllowedIPs { + logger.Debug("Parsing AllowedIP: %s", ipStr) _, ipNet, err := net.ParseCIDR(ipStr) if err != nil { + logger.Warn("Failed to parse allowed IP '%s' for peer %s: %v", ipStr, peer.PublicKey, err) return fmt.Errorf("failed to parse allowed IP: %v", err) } allowedIPs = append(allowedIPs, *ipNet) // Extract the IP address from the CIDR for relay cleanup - wgIPs = append(wgIPs, ipNet.IP.String()) + extractedIP := ipNet.IP.String() + wgIPs = append(wgIPs, extractedIP) + logger.Debug("Extracted IP %s from AllowedIP %s", extractedIP, ipStr) } peerConfig := wgtypes.PeerConfig{ @@ -912,6 +927,18 @@ func addPeerInternal(peer Peer) error { return fmt.Errorf("failed to add peer: %v", err) } + // Setup bandwidth limiting for each peer IP + if doTrafficShaping { + logger.Debug("doTrafficShaping is true, setting up bandwidth limits for %d IPs", len(wgIPs)) + for _, wgIP := range wgIPs { + if err := setupPeerBandwidthLimit(wgIP); err != nil { + logger.Warn("Failed to setup bandwidth limit for peer IP %s: %v", wgIP, err) + } + } + } else { + logger.Debug("doTrafficShaping is false, skipping bandwidth limit setup") + } + // Clear relay connections for the peer's WireGuard IPs if proxyRelay != nil { for _, wgIP := range wgIPs { @@ -956,19 +983,17 @@ func removePeerInternal(publicKey string) error { return fmt.Errorf("failed to parse public key: %v", err) } - // Get current peer info before removing to clear relay connections + // Get current peer info before removing to clear relay connections and bandwidth limits var wgIPs []string - if proxyRelay != nil { - device, err := wgClient.Device(interfaceName) - if err == nil { - for _, peer := range device.Peers { - if peer.PublicKey.String() == publicKey { - // Extract WireGuard IPs from this peer's allowed IPs - for _, allowedIP := range peer.AllowedIPs { - wgIPs = append(wgIPs, allowedIP.IP.String()) - } - break + device, err := wgClient.Device(interfaceName) + if err == nil { + for _, peer := range device.Peers { + if peer.PublicKey.String() == publicKey { + // Extract WireGuard IPs from this peer's allowed IPs + for _, allowedIP := range peer.AllowedIPs { + wgIPs = append(wgIPs, allowedIP.IP.String()) } + break } } } @@ -986,6 +1011,15 @@ func removePeerInternal(publicKey string) error { return fmt.Errorf("failed to remove peer: %v", err) } + // Remove bandwidth limits for each peer IP + if doTrafficShaping { + for _, wgIP := range wgIPs { + if err := removePeerBandwidthLimit(wgIP); err != nil { + logger.Warn("Failed to remove bandwidth limit for peer IP %s: %v", wgIP, err) + } + } + } + // Clear relay connections for the peer's WireGuard IPs if proxyRelay != nil { for _, wgIP := range wgIPs { @@ -1315,3 +1349,155 @@ func monitorMemory(limit uint64) { time.Sleep(5 * time.Second) } } + +// setupPeerBandwidthLimit sets up TC (Traffic Control) to limit bandwidth for a specific peer IP +// Currently hardcoded to 20 Mbps per peer +func setupPeerBandwidthLimit(peerIP string) error { + logger.Debug("setupPeerBandwidthLimit called for peer IP: %s", peerIP) + const bandwidthLimit = "50mbit" // 50 Mbps limit per peer + + // Parse the IP to get just the IP address (strip any CIDR notation if present) + ip := peerIP + if strings.Contains(peerIP, "/") { + parsedIP, _, err := net.ParseCIDR(peerIP) + if err != nil { + return fmt.Errorf("failed to parse peer IP: %v", err) + } + ip = parsedIP.String() + } + + // First, ensure we have a root qdisc on the interface (HTB - Hierarchical Token Bucket) + // Check if qdisc already exists + cmd := exec.Command("tc", "qdisc", "show", "dev", interfaceName) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to check qdisc: %v, output: %s", err, string(output)) + } + + // If no HTB qdisc exists, create one + if !strings.Contains(string(output), "htb") { + cmd = exec.Command("tc", "qdisc", "add", "dev", interfaceName, "root", "handle", "1:", "htb", "default", "9999") + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to add root qdisc: %v, output: %s", err, string(output)) + } + logger.Info("Created HTB root qdisc on %s", interfaceName) + } + + // Generate a unique class ID based on the IP address + // We'll use the last octet of the IP as part of the class ID + ipParts := strings.Split(ip, ".") + if len(ipParts) != 4 { + return fmt.Errorf("invalid IPv4 address: %s", ip) + } + lastOctet := ipParts[3] + classID := fmt.Sprintf("1:%s", lastOctet) + logger.Debug("Generated class ID %s for peer IP %s", classID, ip) + + // Create a class for this peer with bandwidth limit + cmd = exec.Command("tc", "class", "add", "dev", interfaceName, "parent", "1:", "classid", classID, + "htb", "rate", bandwidthLimit, "ceil", bandwidthLimit) + if output, err := cmd.CombinedOutput(); err != nil { + logger.Debug("tc class add failed for %s: %v, output: %s", ip, err, string(output)) + // If class already exists, try to replace it + if strings.Contains(string(output), "File exists") { + cmd = exec.Command("tc", "class", "replace", "dev", interfaceName, "parent", "1:", "classid", classID, + "htb", "rate", bandwidthLimit, "ceil", bandwidthLimit) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to replace class: %v, output: %s", err, string(output)) + } + logger.Debug("Successfully replaced existing class %s for peer IP %s", classID, ip) + } else { + return fmt.Errorf("failed to add class: %v, output: %s", err, string(output)) + } + } else { + logger.Debug("Successfully added new class %s for peer IP %s", classID, ip) + } + + // Add a filter to match traffic from this peer IP (ingress) + cmd = exec.Command("tc", "filter", "add", "dev", interfaceName, "protocol", "ip", "parent", "1:", + "prio", "1", "u32", "match", "ip", "src", ip, "flowid", classID) + if output, err := cmd.CombinedOutput(); err != nil { + // If filter fails, log but don't fail the peer addition + logger.Warn("Failed to add ingress filter for peer IP %s: %v, output: %s", ip, err, string(output)) + } + + // Add a filter to match traffic to this peer IP (egress) + cmd = exec.Command("tc", "filter", "add", "dev", interfaceName, "protocol", "ip", "parent", "1:", + "prio", "1", "u32", "match", "ip", "dst", ip, "flowid", classID) + if output, err := cmd.CombinedOutput(); err != nil { + // If filter fails, log but don't fail the peer addition + logger.Warn("Failed to add egress filter for peer IP %s: %v, output: %s", ip, err, string(output)) + } + + logger.Info("Setup bandwidth limit of %s for peer IP %s (class %s)", bandwidthLimit, ip, classID) + return nil +} + +// removePeerBandwidthLimit removes TC rules for a specific peer IP +func removePeerBandwidthLimit(peerIP string) error { + // Parse the IP to get just the IP address + ip := peerIP + if strings.Contains(peerIP, "/") { + parsedIP, _, err := net.ParseCIDR(peerIP) + if err != nil { + return fmt.Errorf("failed to parse peer IP: %v", err) + } + ip = parsedIP.String() + } + + // Generate the class ID based on the IP + ipParts := strings.Split(ip, ".") + if len(ipParts) != 4 { + return fmt.Errorf("invalid IPv4 address: %s", ip) + } + lastOctet := ipParts[3] + classID := fmt.Sprintf("1:%s", lastOctet) + + // Remove filters for this IP + // List all filters to find the ones for this class + cmd := exec.Command("tc", "filter", "show", "dev", interfaceName, "parent", "1:") + output, err := cmd.CombinedOutput() + if err != nil { + logger.Warn("Failed to list filters for peer IP %s: %v, output: %s", ip, err, string(output)) + } else { + // Parse the output to find filter handles that match this classID + // The output format includes lines like: + // filter parent 1: protocol ip pref 1 u32 chain 0 fh 800::800 order 2048 key ht 800 bkt 0 flowid 1:4 + lines := strings.Split(string(output), "\n") + for _, line := range lines { + // Look for lines containing our flowid (classID) + if strings.Contains(line, "flowid "+classID) && strings.Contains(line, "fh ") { + // Extract handle (format: fh 800::800) + parts := strings.Fields(line) + var handle string + for j, part := range parts { + if part == "fh" && j+1 < len(parts) { + handle = parts[j+1] + break + } + } + if handle != "" { + // Delete this filter using the handle + delCmd := exec.Command("tc", "filter", "del", "dev", interfaceName, "parent", "1:", "handle", handle, "prio", "1", "u32") + if delOutput, delErr := delCmd.CombinedOutput(); delErr != nil { + logger.Debug("Failed to delete filter handle %s for peer IP %s: %v, output: %s", handle, ip, delErr, string(delOutput)) + } else { + logger.Debug("Deleted filter handle %s for peer IP %s", handle, ip) + } + } + } + } + } + + // Remove the class + cmd = exec.Command("tc", "class", "del", "dev", interfaceName, "classid", classID) + if output, err := cmd.CombinedOutput(); err != nil { + // It's okay if the class doesn't exist + if !strings.Contains(string(output), "No such file or directory") && !strings.Contains(string(output), "Cannot find") { + logger.Warn("Failed to remove class for peer IP %s: %v, output: %s", ip, err, string(output)) + } + } + + logger.Info("Removed bandwidth limit for peer IP %s (class %s)", ip, classID) + return nil +}