mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Replace netstack and remove proxy
This commit is contained in:
318
main.go
318
main.go
@@ -1,7 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
@@ -9,7 +8,6 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
@@ -18,16 +16,15 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/proxy"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"github.com/fosrl/newt/wg"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
@@ -62,59 +59,93 @@ func fixKey(key string) string {
|
||||
return hex.EncodeToString(decoded)
|
||||
}
|
||||
|
||||
func ping(tnet *netstack.Net, dst string) error {
|
||||
logger.Info("Pinging %s", dst)
|
||||
socket, err := tnet.Dial("ping4", dst)
|
||||
const (
|
||||
ENV_WG_TUN_FD = "WG_TUN_FD"
|
||||
ENV_WG_UAPI_FD = "WG_UAPI_FD"
|
||||
ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
|
||||
)
|
||||
|
||||
func ping(dev *device.Device, dst string) error {
|
||||
logger.Info("Pinging %s over WireGuard tunnel", dst)
|
||||
|
||||
// Create a raw socket for ICMP
|
||||
conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create ICMP socket: %w", err)
|
||||
}
|
||||
defer socket.Close()
|
||||
defer conn.Close()
|
||||
|
||||
requestPing := icmp.Echo{
|
||||
Seq: rand.Intn(1 << 16),
|
||||
Data: []byte("gopher burrow"),
|
||||
// Parse destination IP
|
||||
dstIP := net.ParseIP(dst)
|
||||
if dstIP == nil {
|
||||
return fmt.Errorf("invalid destination IP: %s", dst)
|
||||
}
|
||||
|
||||
icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
|
||||
// Create ICMP message
|
||||
requestPing := icmp.Echo{
|
||||
ID: os.Getpid() & 0xffff,
|
||||
Seq: rand.Intn(1 << 16),
|
||||
Data: []byte("wireguard ping"),
|
||||
}
|
||||
|
||||
msg := icmp.Message{
|
||||
Type: ipv4.ICMPTypeEcho,
|
||||
Code: 0,
|
||||
Body: &requestPing,
|
||||
}
|
||||
|
||||
// Marshal the message
|
||||
icmpBytes, err := msg.Marshal(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal ICMP message: %w", err)
|
||||
}
|
||||
|
||||
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
|
||||
// Set read deadline
|
||||
if err := conn.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
|
||||
return fmt.Errorf("failed to set read deadline: %w", err)
|
||||
}
|
||||
|
||||
// Send the ping
|
||||
start := time.Now()
|
||||
_, err = socket.Write(icmpBytes)
|
||||
_, err = conn.WriteTo(icmpBytes, &net.IPAddr{IP: dstIP})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write ICMP packet: %w", err)
|
||||
}
|
||||
|
||||
n, err := socket.Read(icmpBytes[:])
|
||||
// Wait for reply
|
||||
reply := make([]byte, 1500)
|
||||
n, peer, err := conn.ReadFrom(reply)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read ICMP packet: %w", err)
|
||||
}
|
||||
|
||||
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
|
||||
// Parse reply
|
||||
replyMsg, err := icmp.ParseMessage(1, reply[:n])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse ICMP packet: %w", err)
|
||||
return fmt.Errorf("failed to parse ICMP reply: %w", err)
|
||||
}
|
||||
|
||||
replyPing, ok := replyPacket.Body.(*icmp.Echo)
|
||||
// Verify reply
|
||||
switch replyMsg.Type {
|
||||
case ipv4.ICMPTypeEchoReply:
|
||||
replyEcho, ok := replyMsg.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body)
|
||||
return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyMsg.Body)
|
||||
}
|
||||
if replyEcho.ID != requestPing.ID || replyEcho.Seq != requestPing.Seq {
|
||||
return fmt.Errorf("invalid echo reply: got id=%d seq=%d, want id=%d seq=%d",
|
||||
replyEcho.ID, replyEcho.Seq, requestPing.ID, requestPing.Seq)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unexpected ICMP message type: %+v", replyMsg)
|
||||
}
|
||||
|
||||
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
|
||||
return fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q",
|
||||
replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data)
|
||||
}
|
||||
|
||||
logger.Info("Ping latency: %v", time.Since(start))
|
||||
duration := time.Since(start)
|
||||
logger.Info("Ping reply from %v: time=%v", peer, duration)
|
||||
return nil
|
||||
}
|
||||
|
||||
func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) {
|
||||
func startPingCheck(dev *device.Device, serverIP string, stopChan chan struct{}) {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -122,10 +153,10 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err := ping(tnet, serverIP)
|
||||
err := ping(dev, serverIP)
|
||||
if err != nil {
|
||||
logger.Warn("Periodic ping failed: %v", err)
|
||||
logger.Warn("HINT: Do you have UDP port 51280 (or the port in config.yml) open on your Pangolin server?")
|
||||
logger.Warn("HINT: Check if the WireGuard tunnel is up and the server is reachable")
|
||||
}
|
||||
case <-stopChan:
|
||||
logger.Info("Stopping ping check")
|
||||
@@ -135,7 +166,7 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
|
||||
}()
|
||||
}
|
||||
|
||||
func pingWithRetry(tnet *netstack.Net, dst string) error {
|
||||
func pingWithRetry(dev *device.Device, dst string) error {
|
||||
const (
|
||||
maxAttempts = 5
|
||||
retryDelay = 2 * time.Second
|
||||
@@ -145,7 +176,7 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
logger.Info("Ping attempt %d of %d", attempt, maxAttempts)
|
||||
|
||||
if err := ping(tnet, dst); err != nil {
|
||||
if err := ping(dev, dst); err != nil {
|
||||
lastErr = err
|
||||
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
||||
|
||||
@@ -161,7 +192,6 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// This shouldn't be reached due to the return in the loop, but added for completeness
|
||||
return fmt.Errorf("unexpected error: all ping attempts failed")
|
||||
}
|
||||
|
||||
@@ -335,29 +365,13 @@ func main() {
|
||||
logger.Fatal("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
// Create WireGuard service
|
||||
wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard service: %v", err)
|
||||
}
|
||||
defer wgService.Close()
|
||||
|
||||
// Create TUN device and network stack
|
||||
var tun tun.Device
|
||||
var tnet *netstack.Net
|
||||
var dev *device.Device
|
||||
var pm *proxy.ProxyManager
|
||||
var connected bool
|
||||
var wgData WgData
|
||||
|
||||
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
|
||||
client.RegisterHandler("client/terminate", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received terminate message")
|
||||
if pm != nil {
|
||||
pm.Stop()
|
||||
}
|
||||
if dev != nil {
|
||||
dev.Close()
|
||||
}
|
||||
client.Close()
|
||||
})
|
||||
|
||||
@@ -365,13 +379,12 @@ func main() {
|
||||
defer close(pingStopChan)
|
||||
|
||||
// Register handlers for different message types
|
||||
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
|
||||
client.RegisterHandler("client/wg/connect", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received registration message")
|
||||
|
||||
if connected {
|
||||
logger.Info("Already connected! But I will send a ping anyway...")
|
||||
// ping(tnet, wgData.ServerIP)
|
||||
err = pingWithRetry(tnet, wgData.ServerIP)
|
||||
err := pingWithRetry(dev, wgData.ServerIP)
|
||||
if err != nil {
|
||||
// Handle complete failure after all retries
|
||||
logger.Warn("Failed to ping %s: %v", wgData.ServerIP, err)
|
||||
@@ -391,17 +404,39 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Received: %+v", msg)
|
||||
tun, tnet, err = netstack.CreateNetTUN(
|
||||
[]netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
|
||||
[]netip.Addr{netip.MustParseAddr(dns)},
|
||||
mtuInt)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create TUN device: %v", err)
|
||||
// logger.Info("Received: %+v", msg)
|
||||
// tun, tnet, err = netstack.CreateNetTUN(
|
||||
// []netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
|
||||
// []netip.Addr{netip.MustParseAddr(dns)},
|
||||
// mtuInt)
|
||||
// if err != nil {
|
||||
// logger.Error("Failed to create TUN device: %v", err)
|
||||
// }
|
||||
|
||||
tdev, err := func() (tun.Device, error) {
|
||||
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
||||
if tunFdStr == "" {
|
||||
return tun.CreateTUN(interfaceName, mtuInt)
|
||||
}
|
||||
|
||||
// construct tun device from supplied fd
|
||||
|
||||
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(int(fd), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "")
|
||||
return tun.CreateTUNFromFile(file, mtuInt)
|
||||
}()
|
||||
|
||||
// Create WireGuard device
|
||||
dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(
|
||||
dev = device.NewDevice(tdev, conn.NewDefaultBind(), device.NewLogger(
|
||||
mapToWireGuardLogLevel(loggerLevel),
|
||||
"wireguard: ",
|
||||
))
|
||||
@@ -433,7 +468,7 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
logger.Info("WireGuard device created. Lets ping the server now...")
|
||||
// Ping to bring the tunnel up on the server side quickly
|
||||
// ping(tnet, wgData.ServerIP)
|
||||
err = pingWithRetry(tnet, wgData.ServerIP)
|
||||
err = pingWithRetry(dev, wgData.ServerIP)
|
||||
if err != nil {
|
||||
// Handle complete failure after all retries
|
||||
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
|
||||
@@ -441,114 +476,16 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
|
||||
if !connected {
|
||||
logger.Info("Starting ping check")
|
||||
startPingCheck(tnet, wgData.ServerIP, pingStopChan)
|
||||
startPingCheck(dev, wgData.ServerIP, pingStopChan)
|
||||
}
|
||||
|
||||
// Create proxy manager
|
||||
pm = proxy.NewProxyManager(tnet)
|
||||
|
||||
connected = true
|
||||
|
||||
// add the targets if there are any
|
||||
if len(wgData.Targets.TCP) > 0 {
|
||||
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
|
||||
}
|
||||
|
||||
if len(wgData.Targets.UDP) > 0 {
|
||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
|
||||
}
|
||||
|
||||
err = pm.Start()
|
||||
if err != nil {
|
||||
logger.Error("Failed to start proxy manager: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received: %+v", msg)
|
||||
|
||||
// if there is no wgData or pm, we can't add targets
|
||||
if wgData.TunnelIP == "" || pm == nil {
|
||||
logger.Info("No tunnel IP or proxy manager available")
|
||||
return
|
||||
}
|
||||
|
||||
targetData, err := parseTargetData(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error parsing target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(targetData.Targets) > 0 {
|
||||
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
|
||||
}
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received: %+v", msg)
|
||||
|
||||
// if there is no wgData or pm, we can't add targets
|
||||
if wgData.TunnelIP == "" || pm == nil {
|
||||
logger.Info("No tunnel IP or proxy manager available")
|
||||
return
|
||||
}
|
||||
|
||||
targetData, err := parseTargetData(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error parsing target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(targetData.Targets) > 0 {
|
||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
|
||||
}
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received: %+v", msg)
|
||||
|
||||
// if there is no wgData or pm, we can't add targets
|
||||
if wgData.TunnelIP == "" || pm == nil {
|
||||
logger.Info("No tunnel IP or proxy manager available")
|
||||
return
|
||||
}
|
||||
|
||||
targetData, err := parseTargetData(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error parsing target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(targetData.Targets) > 0 {
|
||||
updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData)
|
||||
}
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received: %+v", msg)
|
||||
|
||||
// if there is no wgData or pm, we can't add targets
|
||||
if wgData.TunnelIP == "" || pm == nil {
|
||||
logger.Info("No tunnel IP or proxy manager available")
|
||||
return
|
||||
}
|
||||
|
||||
targetData, err := parseTargetData(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error parsing target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(targetData.Targets) > 0 {
|
||||
updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData)
|
||||
}
|
||||
})
|
||||
|
||||
client.OnConnect(func() error {
|
||||
publicKey := privateKey.PublicKey()
|
||||
logger.Debug("Public key: %s", publicKey)
|
||||
|
||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||
err := client.SendMessage("client/wg/register", map[string]interface{}{
|
||||
"publicKey": fmt.Sprintf("%s", publicKey),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -574,62 +511,3 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
// Cleanup
|
||||
dev.Close()
|
||||
}
|
||||
|
||||
func parseTargetData(data interface{}) (TargetData, error) {
|
||||
var targetData TargetData
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
return targetData, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &targetData); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
return targetData, err
|
||||
}
|
||||
return targetData, nil
|
||||
}
|
||||
|
||||
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
|
||||
for _, t := range targetData.Targets {
|
||||
// Split the first number off of the target with : separator and use as the port
|
||||
parts := strings.Split(t, ":")
|
||||
if len(parts) != 3 {
|
||||
logger.Info("Invalid target format: %s", t)
|
||||
continue
|
||||
}
|
||||
|
||||
// Get the port as an int
|
||||
port := 0
|
||||
_, err := fmt.Sscanf(parts[0], "%d", &port)
|
||||
if err != nil {
|
||||
logger.Info("Invalid port: %s", parts[0])
|
||||
continue
|
||||
}
|
||||
|
||||
if action == "add" {
|
||||
target := parts[1] + ":" + parts[2]
|
||||
// Only remove the specific target if it exists
|
||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
||||
if err != nil {
|
||||
// Ignore "target not found" errors as this is expected for new targets
|
||||
if !strings.Contains(err.Error(), "target not found") {
|
||||
logger.Error("Failed to remove existing target: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add the new target
|
||||
pm.AddTarget(proto, tunnelIP, port, target)
|
||||
|
||||
} else if action == "remove" {
|
||||
logger.Info("Removing target with port %d", port)
|
||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
||||
if err != nil {
|
||||
logger.Error("Failed to remove target: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
352
proxy/manager.go
352
proxy/manager.go
@@ -1,352 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
)
|
||||
|
||||
// Target represents a proxy target with its address and port
|
||||
type Target struct {
|
||||
Address string
|
||||
Port int
|
||||
}
|
||||
|
||||
// ProxyManager handles the creation and management of proxy connections
|
||||
type ProxyManager struct {
|
||||
tnet *netstack.Net
|
||||
tcpTargets map[string]map[int]string // map[listenIP]map[port]targetAddress
|
||||
udpTargets map[string]map[int]string
|
||||
listeners []*gonet.TCPListener
|
||||
udpConns []*gonet.UDPConn
|
||||
running bool
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewProxyManager creates a new proxy manager instance
|
||||
func NewProxyManager(tnet *netstack.Net) *ProxyManager {
|
||||
return &ProxyManager{
|
||||
tnet: tnet,
|
||||
tcpTargets: make(map[string]map[int]string),
|
||||
udpTargets: make(map[string]map[int]string),
|
||||
listeners: make([]*gonet.TCPListener, 0),
|
||||
udpConns: make([]*gonet.UDPConn, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddTarget adds as new target for proxying
|
||||
func (pm *ProxyManager) AddTarget(proto, listenIP string, port int, targetAddr string) error {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
switch proto {
|
||||
case "tcp":
|
||||
if pm.tcpTargets[listenIP] == nil {
|
||||
pm.tcpTargets[listenIP] = make(map[int]string)
|
||||
}
|
||||
pm.tcpTargets[listenIP][port] = targetAddr
|
||||
case "udp":
|
||||
if pm.udpTargets[listenIP] == nil {
|
||||
pm.udpTargets[listenIP] = make(map[int]string)
|
||||
}
|
||||
pm.udpTargets[listenIP][port] = targetAddr
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
|
||||
if pm.running {
|
||||
return pm.startTarget(proto, listenIP, port, targetAddr)
|
||||
} else {
|
||||
logger.Debug("Not adding target because not running")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) RemoveTarget(proto, listenIP string, port int) error {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
switch proto {
|
||||
case "tcp":
|
||||
if targets, ok := pm.tcpTargets[listenIP]; ok {
|
||||
delete(targets, port)
|
||||
// Remove and close the corresponding TCP listener
|
||||
for i, listener := range pm.listeners {
|
||||
if addr, ok := listener.Addr().(*net.TCPAddr); ok && addr.Port == port {
|
||||
listener.Close()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
// Remove from slice
|
||||
pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("target not found: %s:%d", listenIP, port)
|
||||
}
|
||||
case "udp":
|
||||
if targets, ok := pm.udpTargets[listenIP]; ok {
|
||||
delete(targets, port)
|
||||
// Remove and close the corresponding UDP connection
|
||||
for i, conn := range pm.udpConns {
|
||||
if addr, ok := conn.LocalAddr().(*net.UDPAddr); ok && addr.Port == port {
|
||||
conn.Close()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
// Remove from slice
|
||||
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("target not found: %s:%d", listenIP, port)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start begins listening for all configured proxy targets
|
||||
func (pm *ProxyManager) Start() error {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
if pm.running {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start TCP targets
|
||||
for listenIP, targets := range pm.tcpTargets {
|
||||
for port, targetAddr := range targets {
|
||||
if err := pm.startTarget("tcp", listenIP, port, targetAddr); err != nil {
|
||||
return fmt.Errorf("failed to start TCP target: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start UDP targets
|
||||
for listenIP, targets := range pm.udpTargets {
|
||||
for port, targetAddr := range targets {
|
||||
if err := pm.startTarget("udp", listenIP, port, targetAddr); err != nil {
|
||||
return fmt.Errorf("failed to start UDP target: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pm.running = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) Stop() error {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
if !pm.running {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set running to false first to signal handlers to stop
|
||||
pm.running = false
|
||||
|
||||
// Close TCP listeners
|
||||
for i := len(pm.listeners) - 1; i >= 0; i-- {
|
||||
listener := pm.listeners[i]
|
||||
if err := listener.Close(); err != nil {
|
||||
logger.Error("Error closing TCP listener: %v", err)
|
||||
}
|
||||
// Remove from slice
|
||||
pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...)
|
||||
}
|
||||
|
||||
// Close UDP connections
|
||||
for i := len(pm.udpConns) - 1; i >= 0; i-- {
|
||||
conn := pm.udpConns[i]
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Error("Error closing UDP connection: %v", err)
|
||||
}
|
||||
// Remove from slice
|
||||
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
||||
}
|
||||
|
||||
// Clear the target maps
|
||||
for k := range pm.tcpTargets {
|
||||
delete(pm.tcpTargets, k)
|
||||
}
|
||||
for k := range pm.udpTargets {
|
||||
delete(pm.udpTargets, k)
|
||||
}
|
||||
|
||||
// Give active connections a chance to close gracefully
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) startTarget(proto, listenIP string, port int, targetAddr string) error {
|
||||
switch proto {
|
||||
case "tcp":
|
||||
listener, err := pm.tnet.ListenTCP(&net.TCPAddr{Port: port})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create TCP listener: %v", err)
|
||||
}
|
||||
|
||||
pm.listeners = append(pm.listeners, listener)
|
||||
go pm.handleTCPProxy(listener, targetAddr)
|
||||
|
||||
case "udp":
|
||||
addr := &net.UDPAddr{Port: port}
|
||||
conn, err := pm.tnet.ListenUDP(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create UDP listener: %v", err)
|
||||
}
|
||||
|
||||
pm.udpConns = append(pm.udpConns, conn)
|
||||
go pm.handleUDPProxy(conn, targetAddr)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
||||
}
|
||||
|
||||
logger.Info("Started %s proxy from %s:%d to %s", proto, listenIP, port, targetAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string) {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
// Check if we're shutting down or the listener was closed
|
||||
if !pm.running {
|
||||
return
|
||||
}
|
||||
|
||||
// Check for specific network errors that indicate the listener is closed
|
||||
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
|
||||
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
|
||||
return
|
||||
}
|
||||
|
||||
logger.Error("Error accepting TCP connection: %v", err)
|
||||
// Don't hammer the CPU if we hit a temporary error
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
go func() {
|
||||
target, err := net.Dial("tcp", targetAddr)
|
||||
if err != nil {
|
||||
logger.Error("Error connecting to target: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// Create a WaitGroup to ensure both copy operations complete
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Copy(target, conn)
|
||||
target.Close()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
io.Copy(conn, target)
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
// Wait for both copies to complete
|
||||
wg.Wait()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||
buffer := make([]byte, 65507) // Max UDP packet size
|
||||
clientConns := make(map[string]*net.UDPConn)
|
||||
var clientsMutex sync.RWMutex
|
||||
|
||||
for {
|
||||
n, remoteAddr, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
if !pm.running {
|
||||
return
|
||||
}
|
||||
|
||||
// Check for connection closed conditions
|
||||
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
|
||||
logger.Info("UDP connection closed, stopping proxy handler")
|
||||
|
||||
// Clean up existing client connections
|
||||
clientsMutex.Lock()
|
||||
for _, targetConn := range clientConns {
|
||||
targetConn.Close()
|
||||
}
|
||||
clientConns = nil
|
||||
clientsMutex.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
logger.Error("Error reading UDP packet: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
clientKey := remoteAddr.String()
|
||||
clientsMutex.RLock()
|
||||
targetConn, exists := clientConns[clientKey]
|
||||
clientsMutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
|
||||
if err != nil {
|
||||
logger.Error("Error resolving target address: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
targetConn, err = net.DialUDP("udp", nil, targetUDPAddr)
|
||||
if err != nil {
|
||||
logger.Error("Error connecting to target: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
clientsMutex.Lock()
|
||||
clientConns[clientKey] = targetConn
|
||||
clientsMutex.Unlock()
|
||||
|
||||
go func() {
|
||||
buffer := make([]byte, 65507)
|
||||
for {
|
||||
n, _, err := targetConn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
logger.Error("Error reading from target: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.WriteTo(buffer[:n], remoteAddr)
|
||||
if err != nil {
|
||||
logger.Error("Error writing to client: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
_, err = targetConn.Write(buffer[:n])
|
||||
if err != nil {
|
||||
logger.Error("Error writing to target: %v", err)
|
||||
targetConn.Close()
|
||||
clientsMutex.Lock()
|
||||
delete(clientConns, clientKey)
|
||||
clientsMutex.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
606
wg/wg.go
606
wg/wg.go
@@ -1,606 +0,0 @@
|
||||
package wg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
var (
|
||||
interfaceName string
|
||||
listenAddr string
|
||||
mtuInt int
|
||||
lastReadings = make(map[string]PeerReading)
|
||||
mu sync.Mutex
|
||||
)
|
||||
|
||||
type WgConfig struct {
|
||||
PrivateKey string `json:"privateKey"`
|
||||
ListenPort int `json:"listenPort"`
|
||||
IpAddress string `json:"ipAddress"`
|
||||
Peers []Peer `json:"peers"`
|
||||
}
|
||||
|
||||
type Peer struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
AllowedIPs []string `json:"allowedIps"`
|
||||
}
|
||||
|
||||
type PeerBandwidth struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
BytesIn float64 `json:"bytesIn"`
|
||||
BytesOut float64 `json:"bytesOut"`
|
||||
}
|
||||
|
||||
type PeerReading struct {
|
||||
BytesReceived int64
|
||||
BytesTransmitted int64
|
||||
LastChecked time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
wgClient *wgctrl.Client
|
||||
)
|
||||
|
||||
type WireGuardService struct {
|
||||
interfaceName string
|
||||
mtu int
|
||||
client *websocket.Client
|
||||
wgClient *wgctrl.Client
|
||||
config WgConfig
|
||||
key wgtypes.Key
|
||||
reachableAt string
|
||||
lastReadings map[string]PeerReading
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) {
|
||||
wgClient, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
|
||||
}
|
||||
|
||||
key := wgtypes.Key{}
|
||||
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
|
||||
if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
|
||||
// generate a new private key
|
||||
key, err = wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to generate private key: %v", err)
|
||||
}
|
||||
// save the key to the file
|
||||
err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to save private key: %v", err)
|
||||
}
|
||||
} else {
|
||||
keyData, err := os.ReadFile(generateAndSaveKeyTo)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to read private key: %v", err)
|
||||
}
|
||||
key, err = wgtypes.ParseKey(string(keyData))
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to parse private key: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
service := &WireGuardService{
|
||||
interfaceName: interfaceName,
|
||||
mtu: mtu,
|
||||
client: wsClient,
|
||||
wgClient: wgClient,
|
||||
key: key,
|
||||
reachableAt: reachableAt,
|
||||
lastReadings: make(map[string]PeerReading),
|
||||
}
|
||||
|
||||
// Register websocket handlers
|
||||
wsClient.RegisterHandler("wg/config/receive", service.handleConfig)
|
||||
wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer)
|
||||
wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer)
|
||||
|
||||
// Register connect handler to initiate configuration
|
||||
wsClient.OnConnect(service.loadRemoteConfig)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) Close() {
|
||||
s.client.Close()
|
||||
wgClient.Close()
|
||||
}
|
||||
|
||||
func (s *WireGuardService) loadRemoteConfig() error {
|
||||
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "endpoint": "%s"}`, s.key.PublicKey().String(), s.reachableAt)))
|
||||
|
||||
go s.periodicBandwidthCheck()
|
||||
|
||||
err := s.client.SendMessage("wg/config/get", body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send config request: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
||||
var config WgConfig
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &config); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
}
|
||||
|
||||
s.config = config
|
||||
|
||||
// Ensure the WireGuard interface and peers are configured
|
||||
if err := s.ensureWireguardInterface(config); err != nil {
|
||||
logger.Error("Failed to ensure WireGuard interface: %v", err)
|
||||
}
|
||||
|
||||
if err := s.ensureWireguardPeers(config.Peers); err != nil {
|
||||
logger.Error("Failed to ensure WireGuard peers: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
// Check if the WireGuard interface exists
|
||||
_, err := netlink.LinkByName(interfaceName)
|
||||
if err != nil {
|
||||
if _, ok := err.(netlink.LinkNotFoundError); ok {
|
||||
// Interface doesn't exist, so create it
|
||||
err = s.createWireGuardInterface()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard interface: %v", err)
|
||||
}
|
||||
logger.Info("Created WireGuard interface %s\n", interfaceName)
|
||||
} else {
|
||||
logger.Fatal("Error checking for WireGuard interface: %v", err)
|
||||
}
|
||||
} else {
|
||||
logger.Info("WireGuard interface %s already exists\n", interfaceName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Assign IP address to the interface
|
||||
err = s.assignIPAddress(wgconfig.IpAddress)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to assign IP address: %v", err)
|
||||
}
|
||||
logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, interfaceName)
|
||||
|
||||
// Check if the interface already exists
|
||||
_, err = wgClient.Device(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interface %s does not exist", interfaceName)
|
||||
}
|
||||
|
||||
// Parse the private key
|
||||
key, err := wgtypes.ParseKey(wgconfig.PrivateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse private key: %v", err)
|
||||
}
|
||||
|
||||
// Create a new WireGuard configuration
|
||||
config := wgtypes.Config{
|
||||
PrivateKey: &key,
|
||||
ListenPort: new(int),
|
||||
}
|
||||
*config.ListenPort = wgconfig.ListenPort
|
||||
|
||||
// Create and configure the WireGuard interface
|
||||
err = wgClient.ConfigureDevice(interfaceName, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure WireGuard device: %v", err)
|
||||
}
|
||||
|
||||
// bring up the interface
|
||||
link, err := netlink.LinkByName(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface: %v", err)
|
||||
}
|
||||
|
||||
if err := netlink.LinkSetMTU(link, mtuInt); err != nil {
|
||||
return fmt.Errorf("failed to set MTU: %v", err)
|
||||
}
|
||||
|
||||
if err := netlink.LinkSetUp(link); err != nil {
|
||||
return fmt.Errorf("failed to bring up interface: %v", err)
|
||||
}
|
||||
|
||||
// if err := s.ensureMSSClamping(); err != nil {
|
||||
// logger.Warn("Failed to ensure MSS clamping: %v", err)
|
||||
// }
|
||||
|
||||
logger.Info("WireGuard interface %s created and configured", interfaceName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) createWireGuardInterface() error {
|
||||
wgLink := &netlink.GenericLink{
|
||||
LinkAttrs: netlink.LinkAttrs{Name: interfaceName},
|
||||
LinkType: "wireguard",
|
||||
}
|
||||
return netlink.LinkAdd(wgLink)
|
||||
}
|
||||
|
||||
func (s *WireGuardService) assignIPAddress(ipAddress string) error {
|
||||
link, err := netlink.LinkByName(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface: %v", err)
|
||||
}
|
||||
|
||||
addr, err := netlink.ParseAddr(ipAddress)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse IP address: %v", err)
|
||||
}
|
||||
|
||||
return netlink.AddrAdd(link, addr)
|
||||
}
|
||||
|
||||
func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
||||
// get the current peers
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
|
||||
// get the peer public keys
|
||||
var currentPeers []string
|
||||
for _, peer := range device.Peers {
|
||||
currentPeers = append(currentPeers, peer.PublicKey.String())
|
||||
}
|
||||
|
||||
// remove any peers that are not in the config
|
||||
for _, peer := range currentPeers {
|
||||
found := false
|
||||
for _, configPeer := range peers {
|
||||
if peer == configPeer.PublicKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err := s.removePeer(peer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove peer: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add any peers that are in the config but not in the current peers
|
||||
for _, configPeer := range peers {
|
||||
found := false
|
||||
for _, peer := range currentPeers {
|
||||
if configPeer.PublicKey == peer {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
err := s.addPeer(configPeer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add peer: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// func (s *WireGuardService) ensureMSSClamping() error {
|
||||
// // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
|
||||
// mssValue := mtuInt - 40
|
||||
|
||||
// // Rules to be managed - just the chains, we'll construct the full command separately
|
||||
// chains := []string{"INPUT", "OUTPUT", "FORWARD"}
|
||||
|
||||
// // First, try to delete any existing rules
|
||||
// for _, chain := range chains {
|
||||
// deleteCmd := exec.Command("/usr/sbin/iptables",
|
||||
// "-t", "mangle",
|
||||
// "-D", chain,
|
||||
// "-p", "tcp",
|
||||
// "--tcp-flags", "SYN,RST", "SYN",
|
||||
// "-j", "TCPMSS",
|
||||
// "--set-mss", fmt.Sprintf("%d", mssValue))
|
||||
|
||||
// logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain)
|
||||
|
||||
// // Try deletion multiple times to handle multiple existing rules
|
||||
// for i := 0; i < 3; i++ {
|
||||
// out, err := deleteCmd.CombinedOutput()
|
||||
// if err != nil {
|
||||
// // Convert exit status 1 to string for better logging
|
||||
// if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
// logger.Debug("Deletion stopped for chain %s: %v (output: %s)",
|
||||
// chain, exitErr.String(), string(out))
|
||||
// }
|
||||
// break // No more rules to delete
|
||||
// }
|
||||
// logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1)
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Then add the new rules
|
||||
// var errors []error
|
||||
// for _, chain := range chains {
|
||||
// addCmd := exec.Command("/usr/sbin/iptables",
|
||||
// "-t", "mangle",
|
||||
// "-A", chain,
|
||||
// "-p", "tcp",
|
||||
// "--tcp-flags", "SYN,RST", "SYN",
|
||||
// "-j", "TCPMSS",
|
||||
// "--set-mss", fmt.Sprintf("%d", mssValue))
|
||||
|
||||
// logger.Info("Adding MSS clamping rule for chain %s", chain)
|
||||
|
||||
// if out, err := addCmd.CombinedOutput(); err != nil {
|
||||
// errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
|
||||
// chain, err, string(out))
|
||||
// logger.Error(errMsg)
|
||||
// errors = append(errors, fmt.Errorf(errMsg))
|
||||
// continue
|
||||
// }
|
||||
|
||||
// // Verify the rule was added
|
||||
// checkCmd := exec.Command("/usr/sbin/iptables",
|
||||
// "-t", "mangle",
|
||||
// "-C", chain,
|
||||
// "-p", "tcp",
|
||||
// "--tcp-flags", "SYN,RST", "SYN",
|
||||
// "-j", "TCPMSS",
|
||||
// "--set-mss", fmt.Sprintf("%d", mssValue))
|
||||
|
||||
// if out, err := checkCmd.CombinedOutput(); err != nil {
|
||||
// errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
|
||||
// chain, err, string(out))
|
||||
// logger.Error(errMsg)
|
||||
// errors = append(errors, fmt.Errorf(errMsg))
|
||||
// continue
|
||||
// }
|
||||
|
||||
// logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain)
|
||||
// }
|
||||
|
||||
// // If we encountered any errors, return them combined
|
||||
// if len(errors) > 0 {
|
||||
// var errMsgs []string
|
||||
// for _, err := range errors {
|
||||
// errMsgs = append(errMsgs, err.Error())
|
||||
// }
|
||||
// return fmt.Errorf("MSS clamping setup encountered errors:\n%s",
|
||||
// strings.Join(errMsgs, "\n"))
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
|
||||
var peer Peer
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &peer); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
}
|
||||
|
||||
err = s.addPeer(peer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) addPeer(peer Peer) error {
|
||||
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
}
|
||||
|
||||
// parse allowed IPs into array of net.IPNet
|
||||
var allowedIPs []net.IPNet
|
||||
for _, ipStr := range peer.AllowedIPs {
|
||||
_, ipNet, err := net.ParseCIDR(ipStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse allowed IP: %v", err)
|
||||
}
|
||||
allowedIPs = append(allowedIPs, *ipNet)
|
||||
}
|
||||
|
||||
peerConfig := wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
AllowedIPs: allowedIPs,
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||
}
|
||||
|
||||
if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
|
||||
return fmt.Errorf("failed to add peer: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("Peer %s added successfully", peer.PublicKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) {
|
||||
// parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" }
|
||||
type RemoveRequest struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
}
|
||||
|
||||
var request RemoveRequest
|
||||
if err := json.Unmarshal(jsonData, &request); err != nil {
|
||||
logger.Info("Error unmarshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.removePeer(request.PublicKey); err != nil {
|
||||
logger.Info("Error removing peer: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) removePeer(publicKey string) error {
|
||||
pubKey, err := wgtypes.ParseKey(publicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %v", err)
|
||||
}
|
||||
|
||||
peerConfig := wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
Remove: true,
|
||||
}
|
||||
|
||||
config := wgtypes.Config{
|
||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
||||
}
|
||||
|
||||
if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
|
||||
return fmt.Errorf("failed to remove peer: %v", err)
|
||||
}
|
||||
|
||||
logger.Info("Peer %s removed successfully", publicKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) periodicBandwidthCheck() {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if err := s.reportPeerBandwidth(); err != nil {
|
||||
logger.Info("Failed to report peer bandwidth: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
device, err := wgClient.Device(interfaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
|
||||
peerBandwidths := []PeerBandwidth{}
|
||||
now := time.Now()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
for _, peer := range device.Peers {
|
||||
publicKey := peer.PublicKey.String()
|
||||
currentReading := PeerReading{
|
||||
BytesReceived: peer.ReceiveBytes,
|
||||
BytesTransmitted: peer.TransmitBytes,
|
||||
LastChecked: now,
|
||||
}
|
||||
|
||||
var bytesInDiff, bytesOutDiff float64
|
||||
lastReading, exists := lastReadings[publicKey]
|
||||
|
||||
if exists {
|
||||
timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds()
|
||||
if timeDiff > 0 {
|
||||
// Calculate bytes transferred since last reading
|
||||
bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived)
|
||||
bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted)
|
||||
|
||||
// Handle counter wraparound (if the counter resets or overflows)
|
||||
if bytesInDiff < 0 {
|
||||
bytesInDiff = float64(currentReading.BytesReceived)
|
||||
}
|
||||
if bytesOutDiff < 0 {
|
||||
bytesOutDiff = float64(currentReading.BytesTransmitted)
|
||||
}
|
||||
|
||||
// Convert to MB
|
||||
bytesInMB := bytesInDiff / (1024 * 1024)
|
||||
bytesOutMB := bytesOutDiff / (1024 * 1024)
|
||||
|
||||
peerBandwidths = append(peerBandwidths, PeerBandwidth{
|
||||
PublicKey: publicKey,
|
||||
BytesIn: bytesInMB,
|
||||
BytesOut: bytesOutMB,
|
||||
})
|
||||
} else {
|
||||
// If readings are too close together or time hasn't passed, report 0
|
||||
peerBandwidths = append(peerBandwidths, PeerBandwidth{
|
||||
PublicKey: publicKey,
|
||||
BytesIn: 0,
|
||||
BytesOut: 0,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// For first reading of a peer, report 0 to establish baseline
|
||||
peerBandwidths = append(peerBandwidths, PeerBandwidth{
|
||||
PublicKey: publicKey,
|
||||
BytesIn: 0,
|
||||
BytesOut: 0,
|
||||
})
|
||||
}
|
||||
|
||||
// Update the last reading
|
||||
lastReadings[publicKey] = currentReading
|
||||
}
|
||||
|
||||
// Clean up old peers
|
||||
for publicKey := range lastReadings {
|
||||
found := false
|
||||
for _, peer := range device.Peers {
|
||||
if peer.PublicKey.String() == publicKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
delete(lastReadings, publicKey)
|
||||
}
|
||||
}
|
||||
|
||||
return peerBandwidths, nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) reportPeerBandwidth() error {
|
||||
bandwidths, err := s.calculatePeerBandwidth()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bandwidths)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal bandwidth data: %v", err)
|
||||
}
|
||||
|
||||
err = s.client.SendMessage("wg/bandwidth", jsonData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send bandwidth data: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user