mirror of
https://github.com/fosrl/gerbil.git
synced 2026-03-03 09:16:45 +00:00
Add optional tc
This commit is contained in:
228
main.go
228
main.go
@@ -33,15 +33,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
interfaceName string
|
interfaceName string
|
||||||
listenAddr string
|
listenAddr string
|
||||||
mtuInt int
|
mtuInt int
|
||||||
lastReadings = make(map[string]PeerReading)
|
lastReadings = make(map[string]PeerReading)
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
wgMu sync.Mutex // Protects WireGuard operations
|
wgMu sync.Mutex // Protects WireGuard operations
|
||||||
notifyURL string
|
notifyURL string
|
||||||
proxyRelay *relay.UDPProxyServer
|
proxyRelay *relay.UDPProxyServer
|
||||||
proxySNI *proxy.SNIProxy
|
proxySNI *proxy.SNIProxy
|
||||||
|
doTrafficShaping bool
|
||||||
)
|
)
|
||||||
|
|
||||||
type WgConfig struct {
|
type WgConfig struct {
|
||||||
@@ -151,6 +152,7 @@ func main() {
|
|||||||
localOverridesStr = os.Getenv("LOCAL_OVERRIDES")
|
localOverridesStr = os.Getenv("LOCAL_OVERRIDES")
|
||||||
trustedUpstreamsStr = os.Getenv("TRUSTED_UPSTREAMS")
|
trustedUpstreamsStr = os.Getenv("TRUSTED_UPSTREAMS")
|
||||||
proxyProtocolStr := os.Getenv("PROXY_PROTOCOL")
|
proxyProtocolStr := os.Getenv("PROXY_PROTOCOL")
|
||||||
|
doTrafficShapingStr := os.Getenv("DO_TRAFFIC_SHAPING")
|
||||||
|
|
||||||
if interfaceName == "" {
|
if interfaceName == "" {
|
||||||
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
|
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
|
||||||
@@ -222,6 +224,13 @@ func main() {
|
|||||||
flag.BoolVar(&proxyProtocol, "proxy-protocol", true, "Enable PROXY protocol v1 for preserving client IP")
|
flag.BoolVar(&proxyProtocol, "proxy-protocol", true, "Enable PROXY protocol v1 for preserving client IP")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if doTrafficShapingStr != "" {
|
||||||
|
doTrafficShaping = strings.ToLower(doTrafficShapingStr) == "true"
|
||||||
|
}
|
||||||
|
if doTrafficShapingStr == "" {
|
||||||
|
flag.BoolVar(&doTrafficShaping, "do-traffic-shaping", false, "Whether to set up traffic shaping rules for peers (requires tc command and root privileges)")
|
||||||
|
}
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
logger.Init()
|
logger.Init()
|
||||||
@@ -886,17 +895,23 @@ func addPeerInternal(peer Peer) error {
|
|||||||
return fmt.Errorf("failed to parse public key: %v", err)
|
return fmt.Errorf("failed to parse public key: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Debug("Adding peer %s with AllowedIPs: %v", peer.PublicKey, peer.AllowedIPs)
|
||||||
|
|
||||||
// parse allowed IPs into array of net.IPNet
|
// parse allowed IPs into array of net.IPNet
|
||||||
var allowedIPs []net.IPNet
|
var allowedIPs []net.IPNet
|
||||||
var wgIPs []string
|
var wgIPs []string
|
||||||
for _, ipStr := range peer.AllowedIPs {
|
for _, ipStr := range peer.AllowedIPs {
|
||||||
|
logger.Debug("Parsing AllowedIP: %s", ipStr)
|
||||||
_, ipNet, err := net.ParseCIDR(ipStr)
|
_, ipNet, err := net.ParseCIDR(ipStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Warn("Failed to parse allowed IP '%s' for peer %s: %v", ipStr, peer.PublicKey, err)
|
||||||
return fmt.Errorf("failed to parse allowed IP: %v", err)
|
return fmt.Errorf("failed to parse allowed IP: %v", err)
|
||||||
}
|
}
|
||||||
allowedIPs = append(allowedIPs, *ipNet)
|
allowedIPs = append(allowedIPs, *ipNet)
|
||||||
// Extract the IP address from the CIDR for relay cleanup
|
// Extract the IP address from the CIDR for relay cleanup
|
||||||
wgIPs = append(wgIPs, ipNet.IP.String())
|
extractedIP := ipNet.IP.String()
|
||||||
|
wgIPs = append(wgIPs, extractedIP)
|
||||||
|
logger.Debug("Extracted IP %s from AllowedIP %s", extractedIP, ipStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
peerConfig := wgtypes.PeerConfig{
|
peerConfig := wgtypes.PeerConfig{
|
||||||
@@ -912,6 +927,18 @@ func addPeerInternal(peer Peer) error {
|
|||||||
return fmt.Errorf("failed to add peer: %v", err)
|
return fmt.Errorf("failed to add peer: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Setup bandwidth limiting for each peer IP
|
||||||
|
if doTrafficShaping {
|
||||||
|
logger.Debug("doTrafficShaping is true, setting up bandwidth limits for %d IPs", len(wgIPs))
|
||||||
|
for _, wgIP := range wgIPs {
|
||||||
|
if err := setupPeerBandwidthLimit(wgIP); err != nil {
|
||||||
|
logger.Warn("Failed to setup bandwidth limit for peer IP %s: %v", wgIP, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Debug("doTrafficShaping is false, skipping bandwidth limit setup")
|
||||||
|
}
|
||||||
|
|
||||||
// Clear relay connections for the peer's WireGuard IPs
|
// Clear relay connections for the peer's WireGuard IPs
|
||||||
if proxyRelay != nil {
|
if proxyRelay != nil {
|
||||||
for _, wgIP := range wgIPs {
|
for _, wgIP := range wgIPs {
|
||||||
@@ -956,19 +983,17 @@ func removePeerInternal(publicKey string) error {
|
|||||||
return fmt.Errorf("failed to parse public key: %v", err)
|
return fmt.Errorf("failed to parse public key: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get current peer info before removing to clear relay connections
|
// Get current peer info before removing to clear relay connections and bandwidth limits
|
||||||
var wgIPs []string
|
var wgIPs []string
|
||||||
if proxyRelay != nil {
|
device, err := wgClient.Device(interfaceName)
|
||||||
device, err := wgClient.Device(interfaceName)
|
if err == nil {
|
||||||
if err == nil {
|
for _, peer := range device.Peers {
|
||||||
for _, peer := range device.Peers {
|
if peer.PublicKey.String() == publicKey {
|
||||||
if peer.PublicKey.String() == publicKey {
|
// Extract WireGuard IPs from this peer's allowed IPs
|
||||||
// Extract WireGuard IPs from this peer's allowed IPs
|
for _, allowedIP := range peer.AllowedIPs {
|
||||||
for _, allowedIP := range peer.AllowedIPs {
|
wgIPs = append(wgIPs, allowedIP.IP.String())
|
||||||
wgIPs = append(wgIPs, allowedIP.IP.String())
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -986,6 +1011,15 @@ func removePeerInternal(publicKey string) error {
|
|||||||
return fmt.Errorf("failed to remove peer: %v", err)
|
return fmt.Errorf("failed to remove peer: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove bandwidth limits for each peer IP
|
||||||
|
if doTrafficShaping {
|
||||||
|
for _, wgIP := range wgIPs {
|
||||||
|
if err := removePeerBandwidthLimit(wgIP); err != nil {
|
||||||
|
logger.Warn("Failed to remove bandwidth limit for peer IP %s: %v", wgIP, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Clear relay connections for the peer's WireGuard IPs
|
// Clear relay connections for the peer's WireGuard IPs
|
||||||
if proxyRelay != nil {
|
if proxyRelay != nil {
|
||||||
for _, wgIP := range wgIPs {
|
for _, wgIP := range wgIPs {
|
||||||
@@ -1315,3 +1349,155 @@ func monitorMemory(limit uint64) {
|
|||||||
time.Sleep(5 * time.Second)
|
time.Sleep(5 * time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setupPeerBandwidthLimit sets up TC (Traffic Control) to limit bandwidth for a specific peer IP
|
||||||
|
// Currently hardcoded to 20 Mbps per peer
|
||||||
|
func setupPeerBandwidthLimit(peerIP string) error {
|
||||||
|
logger.Debug("setupPeerBandwidthLimit called for peer IP: %s", peerIP)
|
||||||
|
const bandwidthLimit = "50mbit" // 50 Mbps limit per peer
|
||||||
|
|
||||||
|
// Parse the IP to get just the IP address (strip any CIDR notation if present)
|
||||||
|
ip := peerIP
|
||||||
|
if strings.Contains(peerIP, "/") {
|
||||||
|
parsedIP, _, err := net.ParseCIDR(peerIP)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse peer IP: %v", err)
|
||||||
|
}
|
||||||
|
ip = parsedIP.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// First, ensure we have a root qdisc on the interface (HTB - Hierarchical Token Bucket)
|
||||||
|
// Check if qdisc already exists
|
||||||
|
cmd := exec.Command("tc", "qdisc", "show", "dev", interfaceName)
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check qdisc: %v, output: %s", err, string(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no HTB qdisc exists, create one
|
||||||
|
if !strings.Contains(string(output), "htb") {
|
||||||
|
cmd = exec.Command("tc", "qdisc", "add", "dev", interfaceName, "root", "handle", "1:", "htb", "default", "9999")
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("failed to add root qdisc: %v, output: %s", err, string(output))
|
||||||
|
}
|
||||||
|
logger.Info("Created HTB root qdisc on %s", interfaceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a unique class ID based on the IP address
|
||||||
|
// We'll use the last octet of the IP as part of the class ID
|
||||||
|
ipParts := strings.Split(ip, ".")
|
||||||
|
if len(ipParts) != 4 {
|
||||||
|
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||||
|
}
|
||||||
|
lastOctet := ipParts[3]
|
||||||
|
classID := fmt.Sprintf("1:%s", lastOctet)
|
||||||
|
logger.Debug("Generated class ID %s for peer IP %s", classID, ip)
|
||||||
|
|
||||||
|
// Create a class for this peer with bandwidth limit
|
||||||
|
cmd = exec.Command("tc", "class", "add", "dev", interfaceName, "parent", "1:", "classid", classID,
|
||||||
|
"htb", "rate", bandwidthLimit, "ceil", bandwidthLimit)
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
logger.Debug("tc class add failed for %s: %v, output: %s", ip, err, string(output))
|
||||||
|
// If class already exists, try to replace it
|
||||||
|
if strings.Contains(string(output), "File exists") {
|
||||||
|
cmd = exec.Command("tc", "class", "replace", "dev", interfaceName, "parent", "1:", "classid", classID,
|
||||||
|
"htb", "rate", bandwidthLimit, "ceil", bandwidthLimit)
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("failed to replace class: %v, output: %s", err, string(output))
|
||||||
|
}
|
||||||
|
logger.Debug("Successfully replaced existing class %s for peer IP %s", classID, ip)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("failed to add class: %v, output: %s", err, string(output))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Debug("Successfully added new class %s for peer IP %s", classID, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a filter to match traffic from this peer IP (ingress)
|
||||||
|
cmd = exec.Command("tc", "filter", "add", "dev", interfaceName, "protocol", "ip", "parent", "1:",
|
||||||
|
"prio", "1", "u32", "match", "ip", "src", ip, "flowid", classID)
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
// If filter fails, log but don't fail the peer addition
|
||||||
|
logger.Warn("Failed to add ingress filter for peer IP %s: %v, output: %s", ip, err, string(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add a filter to match traffic to this peer IP (egress)
|
||||||
|
cmd = exec.Command("tc", "filter", "add", "dev", interfaceName, "protocol", "ip", "parent", "1:",
|
||||||
|
"prio", "1", "u32", "match", "ip", "dst", ip, "flowid", classID)
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
// If filter fails, log but don't fail the peer addition
|
||||||
|
logger.Warn("Failed to add egress filter for peer IP %s: %v, output: %s", ip, err, string(output))
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Setup bandwidth limit of %s for peer IP %s (class %s)", bandwidthLimit, ip, classID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removePeerBandwidthLimit removes TC rules for a specific peer IP
|
||||||
|
func removePeerBandwidthLimit(peerIP string) error {
|
||||||
|
// Parse the IP to get just the IP address
|
||||||
|
ip := peerIP
|
||||||
|
if strings.Contains(peerIP, "/") {
|
||||||
|
parsedIP, _, err := net.ParseCIDR(peerIP)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse peer IP: %v", err)
|
||||||
|
}
|
||||||
|
ip = parsedIP.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate the class ID based on the IP
|
||||||
|
ipParts := strings.Split(ip, ".")
|
||||||
|
if len(ipParts) != 4 {
|
||||||
|
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||||
|
}
|
||||||
|
lastOctet := ipParts[3]
|
||||||
|
classID := fmt.Sprintf("1:%s", lastOctet)
|
||||||
|
|
||||||
|
// Remove filters for this IP
|
||||||
|
// List all filters to find the ones for this class
|
||||||
|
cmd := exec.Command("tc", "filter", "show", "dev", interfaceName, "parent", "1:")
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to list filters for peer IP %s: %v, output: %s", ip, err, string(output))
|
||||||
|
} else {
|
||||||
|
// Parse the output to find filter handles that match this classID
|
||||||
|
// The output format includes lines like:
|
||||||
|
// filter parent 1: protocol ip pref 1 u32 chain 0 fh 800::800 order 2048 key ht 800 bkt 0 flowid 1:4
|
||||||
|
lines := strings.Split(string(output), "\n")
|
||||||
|
for _, line := range lines {
|
||||||
|
// Look for lines containing our flowid (classID)
|
||||||
|
if strings.Contains(line, "flowid "+classID) && strings.Contains(line, "fh ") {
|
||||||
|
// Extract handle (format: fh 800::800)
|
||||||
|
parts := strings.Fields(line)
|
||||||
|
var handle string
|
||||||
|
for j, part := range parts {
|
||||||
|
if part == "fh" && j+1 < len(parts) {
|
||||||
|
handle = parts[j+1]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if handle != "" {
|
||||||
|
// Delete this filter using the handle
|
||||||
|
delCmd := exec.Command("tc", "filter", "del", "dev", interfaceName, "parent", "1:", "handle", handle, "prio", "1", "u32")
|
||||||
|
if delOutput, delErr := delCmd.CombinedOutput(); delErr != nil {
|
||||||
|
logger.Debug("Failed to delete filter handle %s for peer IP %s: %v, output: %s", handle, ip, delErr, string(delOutput))
|
||||||
|
} else {
|
||||||
|
logger.Debug("Deleted filter handle %s for peer IP %s", handle, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the class
|
||||||
|
cmd = exec.Command("tc", "class", "del", "dev", interfaceName, "classid", classID)
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
// It's okay if the class doesn't exist
|
||||||
|
if !strings.Contains(string(output), "No such file or directory") && !strings.Contains(string(output), "Cannot find") {
|
||||||
|
logger.Warn("Failed to remove class for peer IP %s: %v, output: %s", ip, err, string(output))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Removed bandwidth limit for peer IP %s (class %s)", ip, classID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user