mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Replace netstack and remove proxy
This commit is contained in:
322
main.go
322
main.go
@@ -1,7 +1,6 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -9,7 +8,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -18,16 +16,15 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/proxy"
|
|
||||||
"github.com/fosrl/newt/websocket"
|
"github.com/fosrl/newt/websocket"
|
||||||
"github.com/fosrl/newt/wg"
|
|
||||||
|
|
||||||
"golang.org/x/net/icmp"
|
"golang.org/x/net/icmp"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
"golang.zx2c4.com/wireguard/conn"
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -62,59 +59,93 @@ func fixKey(key string) string {
|
|||||||
return hex.EncodeToString(decoded)
|
return hex.EncodeToString(decoded)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ping(tnet *netstack.Net, dst string) error {
|
const (
|
||||||
logger.Info("Pinging %s", dst)
|
ENV_WG_TUN_FD = "WG_TUN_FD"
|
||||||
socket, err := tnet.Dial("ping4", dst)
|
ENV_WG_UAPI_FD = "WG_UAPI_FD"
|
||||||
|
ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ping(dev *device.Device, dst string) error {
|
||||||
|
logger.Info("Pinging %s over WireGuard tunnel", dst)
|
||||||
|
|
||||||
|
// Create a raw socket for ICMP
|
||||||
|
conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create ICMP socket: %w", err)
|
return fmt.Errorf("failed to create ICMP socket: %w", err)
|
||||||
}
|
}
|
||||||
defer socket.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
requestPing := icmp.Echo{
|
// Parse destination IP
|
||||||
Seq: rand.Intn(1 << 16),
|
dstIP := net.ParseIP(dst)
|
||||||
Data: []byte("gopher burrow"),
|
if dstIP == nil {
|
||||||
|
return fmt.Errorf("invalid destination IP: %s", dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
|
// Create ICMP message
|
||||||
|
requestPing := icmp.Echo{
|
||||||
|
ID: os.Getpid() & 0xffff,
|
||||||
|
Seq: rand.Intn(1 << 16),
|
||||||
|
Data: []byte("wireguard ping"),
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := icmp.Message{
|
||||||
|
Type: ipv4.ICMPTypeEcho,
|
||||||
|
Code: 0,
|
||||||
|
Body: &requestPing,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal the message
|
||||||
|
icmpBytes, err := msg.Marshal(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to marshal ICMP message: %w", err)
|
return fmt.Errorf("failed to marshal ICMP message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
|
// Set read deadline
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
|
||||||
return fmt.Errorf("failed to set read deadline: %w", err)
|
return fmt.Errorf("failed to set read deadline: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Send the ping
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
_, err = socket.Write(icmpBytes)
|
_, err = conn.WriteTo(icmpBytes, &net.IPAddr{IP: dstIP})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to write ICMP packet: %w", err)
|
return fmt.Errorf("failed to write ICMP packet: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := socket.Read(icmpBytes[:])
|
// Wait for reply
|
||||||
|
reply := make([]byte, 1500)
|
||||||
|
n, peer, err := conn.ReadFrom(reply)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to read ICMP packet: %w", err)
|
return fmt.Errorf("failed to read ICMP packet: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
|
// Parse reply
|
||||||
|
replyMsg, err := icmp.ParseMessage(1, reply[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to parse ICMP packet: %w", err)
|
return fmt.Errorf("failed to parse ICMP reply: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
replyPing, ok := replyPacket.Body.(*icmp.Echo)
|
// Verify reply
|
||||||
if !ok {
|
switch replyMsg.Type {
|
||||||
return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body)
|
case ipv4.ICMPTypeEchoReply:
|
||||||
|
replyEcho, ok := replyMsg.Body.(*icmp.Echo)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyMsg.Body)
|
||||||
|
}
|
||||||
|
if replyEcho.ID != requestPing.ID || replyEcho.Seq != requestPing.Seq {
|
||||||
|
return fmt.Errorf("invalid echo reply: got id=%d seq=%d, want id=%d seq=%d",
|
||||||
|
replyEcho.ID, replyEcho.Seq, requestPing.ID, requestPing.Seq)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unexpected ICMP message type: %+v", replyMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
|
duration := time.Since(start)
|
||||||
return fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q",
|
logger.Info("Ping reply from %v: time=%v", peer, duration)
|
||||||
replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Ping latency: %v", time.Since(start))
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) {
|
func startPingCheck(dev *device.Device, serverIP string, stopChan chan struct{}) {
|
||||||
ticker := time.NewTicker(10 * time.Second)
|
ticker := time.NewTicker(10 * time.Second)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
@@ -122,10 +153,10 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
err := ping(tnet, serverIP)
|
err := ping(dev, serverIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn("Periodic ping failed: %v", err)
|
logger.Warn("Periodic ping failed: %v", err)
|
||||||
logger.Warn("HINT: Do you have UDP port 51280 (or the port in config.yml) open on your Pangolin server?")
|
logger.Warn("HINT: Check if the WireGuard tunnel is up and the server is reachable")
|
||||||
}
|
}
|
||||||
case <-stopChan:
|
case <-stopChan:
|
||||||
logger.Info("Stopping ping check")
|
logger.Info("Stopping ping check")
|
||||||
@@ -135,7 +166,7 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func pingWithRetry(tnet *netstack.Net, dst string) error {
|
func pingWithRetry(dev *device.Device, dst string) error {
|
||||||
const (
|
const (
|
||||||
maxAttempts = 5
|
maxAttempts = 5
|
||||||
retryDelay = 2 * time.Second
|
retryDelay = 2 * time.Second
|
||||||
@@ -145,7 +176,7 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
|
|||||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||||
logger.Info("Ping attempt %d of %d", attempt, maxAttempts)
|
logger.Info("Ping attempt %d of %d", attempt, maxAttempts)
|
||||||
|
|
||||||
if err := ping(tnet, dst); err != nil {
|
if err := ping(dev, dst); err != nil {
|
||||||
lastErr = err
|
lastErr = err
|
||||||
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
||||||
|
|
||||||
@@ -161,7 +192,6 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// This shouldn't be reached due to the return in the loop, but added for completeness
|
|
||||||
return fmt.Errorf("unexpected error: all ping attempts failed")
|
return fmt.Errorf("unexpected error: all ping attempts failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -335,29 +365,13 @@ func main() {
|
|||||||
logger.Fatal("Failed to create client: %v", err)
|
logger.Fatal("Failed to create client: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create WireGuard service
|
|
||||||
wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to create WireGuard service: %v", err)
|
|
||||||
}
|
|
||||||
defer wgService.Close()
|
|
||||||
|
|
||||||
// Create TUN device and network stack
|
// Create TUN device and network stack
|
||||||
var tun tun.Device
|
|
||||||
var tnet *netstack.Net
|
|
||||||
var dev *device.Device
|
var dev *device.Device
|
||||||
var pm *proxy.ProxyManager
|
|
||||||
var connected bool
|
var connected bool
|
||||||
var wgData WgData
|
var wgData WgData
|
||||||
|
|
||||||
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
|
client.RegisterHandler("client/terminate", func(msg websocket.WSMessage) {
|
||||||
logger.Info("Received terminate message")
|
logger.Info("Received terminate message")
|
||||||
if pm != nil {
|
|
||||||
pm.Stop()
|
|
||||||
}
|
|
||||||
if dev != nil {
|
|
||||||
dev.Close()
|
|
||||||
}
|
|
||||||
client.Close()
|
client.Close()
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -365,13 +379,12 @@ func main() {
|
|||||||
defer close(pingStopChan)
|
defer close(pingStopChan)
|
||||||
|
|
||||||
// Register handlers for different message types
|
// Register handlers for different message types
|
||||||
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
|
client.RegisterHandler("client/wg/connect", func(msg websocket.WSMessage) {
|
||||||
logger.Info("Received registration message")
|
logger.Info("Received registration message")
|
||||||
|
|
||||||
if connected {
|
if connected {
|
||||||
logger.Info("Already connected! But I will send a ping anyway...")
|
logger.Info("Already connected! But I will send a ping anyway...")
|
||||||
// ping(tnet, wgData.ServerIP)
|
err := pingWithRetry(dev, wgData.ServerIP)
|
||||||
err = pingWithRetry(tnet, wgData.ServerIP)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Handle complete failure after all retries
|
// Handle complete failure after all retries
|
||||||
logger.Warn("Failed to ping %s: %v", wgData.ServerIP, err)
|
logger.Warn("Failed to ping %s: %v", wgData.ServerIP, err)
|
||||||
@@ -391,17 +404,39 @@ func main() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Received: %+v", msg)
|
// logger.Info("Received: %+v", msg)
|
||||||
tun, tnet, err = netstack.CreateNetTUN(
|
// tun, tnet, err = netstack.CreateNetTUN(
|
||||||
[]netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
|
// []netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
|
||||||
[]netip.Addr{netip.MustParseAddr(dns)},
|
// []netip.Addr{netip.MustParseAddr(dns)},
|
||||||
mtuInt)
|
// mtuInt)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
logger.Error("Failed to create TUN device: %v", err)
|
// logger.Error("Failed to create TUN device: %v", err)
|
||||||
}
|
// }
|
||||||
|
|
||||||
|
tdev, err := func() (tun.Device, error) {
|
||||||
|
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
||||||
|
if tunFdStr == "" {
|
||||||
|
return tun.CreateTUN(interfaceName, mtuInt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// construct tun device from supplied fd
|
||||||
|
|
||||||
|
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = unix.SetNonblock(int(fd), true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
file := os.NewFile(uintptr(fd), "")
|
||||||
|
return tun.CreateTUNFromFile(file, mtuInt)
|
||||||
|
}()
|
||||||
|
|
||||||
// Create WireGuard device
|
// Create WireGuard device
|
||||||
dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(
|
dev = device.NewDevice(tdev, conn.NewDefaultBind(), device.NewLogger(
|
||||||
mapToWireGuardLogLevel(loggerLevel),
|
mapToWireGuardLogLevel(loggerLevel),
|
||||||
"wireguard: ",
|
"wireguard: ",
|
||||||
))
|
))
|
||||||
@@ -433,7 +468,7 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
|||||||
logger.Info("WireGuard device created. Lets ping the server now...")
|
logger.Info("WireGuard device created. Lets ping the server now...")
|
||||||
// Ping to bring the tunnel up on the server side quickly
|
// Ping to bring the tunnel up on the server side quickly
|
||||||
// ping(tnet, wgData.ServerIP)
|
// ping(tnet, wgData.ServerIP)
|
||||||
err = pingWithRetry(tnet, wgData.ServerIP)
|
err = pingWithRetry(dev, wgData.ServerIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Handle complete failure after all retries
|
// Handle complete failure after all retries
|
||||||
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
|
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
|
||||||
@@ -441,114 +476,16 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
|||||||
|
|
||||||
if !connected {
|
if !connected {
|
||||||
logger.Info("Starting ping check")
|
logger.Info("Starting ping check")
|
||||||
startPingCheck(tnet, wgData.ServerIP, pingStopChan)
|
startPingCheck(dev, wgData.ServerIP, pingStopChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create proxy manager
|
|
||||||
pm = proxy.NewProxyManager(tnet)
|
|
||||||
|
|
||||||
connected = true
|
connected = true
|
||||||
|
|
||||||
// add the targets if there are any
|
|
||||||
if len(wgData.Targets.TCP) > 0 {
|
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(wgData.Targets.UDP) > 0 {
|
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
|
|
||||||
}
|
|
||||||
|
|
||||||
err = pm.Start()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to start proxy manager: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
client.OnConnect(func() error {
|
client.OnConnect(func() error {
|
||||||
publicKey := privateKey.PublicKey()
|
publicKey := privateKey.PublicKey()
|
||||||
logger.Debug("Public key: %s", publicKey)
|
logger.Debug("Public key: %s", publicKey)
|
||||||
|
|
||||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
err := client.SendMessage("client/wg/register", map[string]interface{}{
|
||||||
"publicKey": fmt.Sprintf("%s", publicKey),
|
"publicKey": fmt.Sprintf("%s", publicKey),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -574,62 +511,3 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
|||||||
// Cleanup
|
// Cleanup
|
||||||
dev.Close()
|
dev.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseTargetData(data interface{}) (TargetData, error) {
|
|
||||||
var targetData TargetData
|
|
||||||
jsonData, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error marshaling data: %v", err)
|
|
||||||
return targetData, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(jsonData, &targetData); err != nil {
|
|
||||||
logger.Info("Error unmarshaling target data: %v", err)
|
|
||||||
return targetData, err
|
|
||||||
}
|
|
||||||
return targetData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
|
|
||||||
for _, t := range targetData.Targets {
|
|
||||||
// Split the first number off of the target with : separator and use as the port
|
|
||||||
parts := strings.Split(t, ":")
|
|
||||||
if len(parts) != 3 {
|
|
||||||
logger.Info("Invalid target format: %s", t)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the port as an int
|
|
||||||
port := 0
|
|
||||||
_, err := fmt.Sscanf(parts[0], "%d", &port)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Invalid port: %s", parts[0])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if action == "add" {
|
|
||||||
target := parts[1] + ":" + parts[2]
|
|
||||||
// Only remove the specific target if it exists
|
|
||||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
|
||||||
if err != nil {
|
|
||||||
// Ignore "target not found" errors as this is expected for new targets
|
|
||||||
if !strings.Contains(err.Error(), "target not found") {
|
|
||||||
logger.Error("Failed to remove existing target: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the new target
|
|
||||||
pm.AddTarget(proto, tunnelIP, port, target)
|
|
||||||
|
|
||||||
} else if action == "remove" {
|
|
||||||
logger.Info("Removing target with port %d", port)
|
|
||||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to remove target: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
352
proxy/manager.go
352
proxy/manager.go
@@ -1,352 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Target represents a proxy target with its address and port
|
|
||||||
type Target struct {
|
|
||||||
Address string
|
|
||||||
Port int
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProxyManager handles the creation and management of proxy connections
|
|
||||||
type ProxyManager struct {
|
|
||||||
tnet *netstack.Net
|
|
||||||
tcpTargets map[string]map[int]string // map[listenIP]map[port]targetAddress
|
|
||||||
udpTargets map[string]map[int]string
|
|
||||||
listeners []*gonet.TCPListener
|
|
||||||
udpConns []*gonet.UDPConn
|
|
||||||
running bool
|
|
||||||
mutex sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewProxyManager creates a new proxy manager instance
|
|
||||||
func NewProxyManager(tnet *netstack.Net) *ProxyManager {
|
|
||||||
return &ProxyManager{
|
|
||||||
tnet: tnet,
|
|
||||||
tcpTargets: make(map[string]map[int]string),
|
|
||||||
udpTargets: make(map[string]map[int]string),
|
|
||||||
listeners: make([]*gonet.TCPListener, 0),
|
|
||||||
udpConns: make([]*gonet.UDPConn, 0),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddTarget adds as new target for proxying
|
|
||||||
func (pm *ProxyManager) AddTarget(proto, listenIP string, port int, targetAddr string) error {
|
|
||||||
pm.mutex.Lock()
|
|
||||||
defer pm.mutex.Unlock()
|
|
||||||
|
|
||||||
switch proto {
|
|
||||||
case "tcp":
|
|
||||||
if pm.tcpTargets[listenIP] == nil {
|
|
||||||
pm.tcpTargets[listenIP] = make(map[int]string)
|
|
||||||
}
|
|
||||||
pm.tcpTargets[listenIP][port] = targetAddr
|
|
||||||
case "udp":
|
|
||||||
if pm.udpTargets[listenIP] == nil {
|
|
||||||
pm.udpTargets[listenIP] = make(map[int]string)
|
|
||||||
}
|
|
||||||
pm.udpTargets[listenIP][port] = targetAddr
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
|
||||||
}
|
|
||||||
|
|
||||||
if pm.running {
|
|
||||||
return pm.startTarget(proto, listenIP, port, targetAddr)
|
|
||||||
} else {
|
|
||||||
logger.Debug("Not adding target because not running")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) RemoveTarget(proto, listenIP string, port int) error {
|
|
||||||
pm.mutex.Lock()
|
|
||||||
defer pm.mutex.Unlock()
|
|
||||||
|
|
||||||
switch proto {
|
|
||||||
case "tcp":
|
|
||||||
if targets, ok := pm.tcpTargets[listenIP]; ok {
|
|
||||||
delete(targets, port)
|
|
||||||
// Remove and close the corresponding TCP listener
|
|
||||||
for i, listener := range pm.listeners {
|
|
||||||
if addr, ok := listener.Addr().(*net.TCPAddr); ok && addr.Port == port {
|
|
||||||
listener.Close()
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
// Remove from slice
|
|
||||||
pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("target not found: %s:%d", listenIP, port)
|
|
||||||
}
|
|
||||||
case "udp":
|
|
||||||
if targets, ok := pm.udpTargets[listenIP]; ok {
|
|
||||||
delete(targets, port)
|
|
||||||
// Remove and close the corresponding UDP connection
|
|
||||||
for i, conn := range pm.udpConns {
|
|
||||||
if addr, ok := conn.LocalAddr().(*net.UDPAddr); ok && addr.Port == port {
|
|
||||||
conn.Close()
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
// Remove from slice
|
|
||||||
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("target not found: %s:%d", listenIP, port)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start begins listening for all configured proxy targets
|
|
||||||
func (pm *ProxyManager) Start() error {
|
|
||||||
pm.mutex.Lock()
|
|
||||||
defer pm.mutex.Unlock()
|
|
||||||
|
|
||||||
if pm.running {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start TCP targets
|
|
||||||
for listenIP, targets := range pm.tcpTargets {
|
|
||||||
for port, targetAddr := range targets {
|
|
||||||
if err := pm.startTarget("tcp", listenIP, port, targetAddr); err != nil {
|
|
||||||
return fmt.Errorf("failed to start TCP target: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start UDP targets
|
|
||||||
for listenIP, targets := range pm.udpTargets {
|
|
||||||
for port, targetAddr := range targets {
|
|
||||||
if err := pm.startTarget("udp", listenIP, port, targetAddr); err != nil {
|
|
||||||
return fmt.Errorf("failed to start UDP target: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pm.running = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) Stop() error {
|
|
||||||
pm.mutex.Lock()
|
|
||||||
defer pm.mutex.Unlock()
|
|
||||||
|
|
||||||
if !pm.running {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set running to false first to signal handlers to stop
|
|
||||||
pm.running = false
|
|
||||||
|
|
||||||
// Close TCP listeners
|
|
||||||
for i := len(pm.listeners) - 1; i >= 0; i-- {
|
|
||||||
listener := pm.listeners[i]
|
|
||||||
if err := listener.Close(); err != nil {
|
|
||||||
logger.Error("Error closing TCP listener: %v", err)
|
|
||||||
}
|
|
||||||
// Remove from slice
|
|
||||||
pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close UDP connections
|
|
||||||
for i := len(pm.udpConns) - 1; i >= 0; i-- {
|
|
||||||
conn := pm.udpConns[i]
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
logger.Error("Error closing UDP connection: %v", err)
|
|
||||||
}
|
|
||||||
// Remove from slice
|
|
||||||
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear the target maps
|
|
||||||
for k := range pm.tcpTargets {
|
|
||||||
delete(pm.tcpTargets, k)
|
|
||||||
}
|
|
||||||
for k := range pm.udpTargets {
|
|
||||||
delete(pm.udpTargets, k)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Give active connections a chance to close gracefully
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) startTarget(proto, listenIP string, port int, targetAddr string) error {
|
|
||||||
switch proto {
|
|
||||||
case "tcp":
|
|
||||||
listener, err := pm.tnet.ListenTCP(&net.TCPAddr{Port: port})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create TCP listener: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pm.listeners = append(pm.listeners, listener)
|
|
||||||
go pm.handleTCPProxy(listener, targetAddr)
|
|
||||||
|
|
||||||
case "udp":
|
|
||||||
addr := &net.UDPAddr{Port: port}
|
|
||||||
conn, err := pm.tnet.ListenUDP(addr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create UDP listener: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pm.udpConns = append(pm.udpConns, conn)
|
|
||||||
go pm.handleUDPProxy(conn, targetAddr)
|
|
||||||
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unsupported protocol: %s", proto)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Started %s proxy from %s:%d to %s", proto, listenIP, port, targetAddr)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string) {
|
|
||||||
for {
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
// Check if we're shutting down or the listener was closed
|
|
||||||
if !pm.running {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for specific network errors that indicate the listener is closed
|
|
||||||
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
|
|
||||||
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Error("Error accepting TCP connection: %v", err)
|
|
||||||
// Don't hammer the CPU if we hit a temporary error
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
target, err := net.Dial("tcp", targetAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error connecting to target: %v", err)
|
|
||||||
conn.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a WaitGroup to ensure both copy operations complete
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(2)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
io.Copy(target, conn)
|
|
||||||
target.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
io.Copy(conn, target)
|
|
||||||
conn.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Wait for both copies to complete
|
|
||||||
wg.Wait()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
|
||||||
buffer := make([]byte, 65507) // Max UDP packet size
|
|
||||||
clientConns := make(map[string]*net.UDPConn)
|
|
||||||
var clientsMutex sync.RWMutex
|
|
||||||
|
|
||||||
for {
|
|
||||||
n, remoteAddr, err := conn.ReadFrom(buffer)
|
|
||||||
if err != nil {
|
|
||||||
if !pm.running {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for connection closed conditions
|
|
||||||
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
|
|
||||||
logger.Info("UDP connection closed, stopping proxy handler")
|
|
||||||
|
|
||||||
// Clean up existing client connections
|
|
||||||
clientsMutex.Lock()
|
|
||||||
for _, targetConn := range clientConns {
|
|
||||||
targetConn.Close()
|
|
||||||
}
|
|
||||||
clientConns = nil
|
|
||||||
clientsMutex.Unlock()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Error("Error reading UDP packet: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
clientKey := remoteAddr.String()
|
|
||||||
clientsMutex.RLock()
|
|
||||||
targetConn, exists := clientConns[clientKey]
|
|
||||||
clientsMutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error resolving target address: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
targetConn, err = net.DialUDP("udp", nil, targetUDPAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error connecting to target: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
clientsMutex.Lock()
|
|
||||||
clientConns[clientKey] = targetConn
|
|
||||||
clientsMutex.Unlock()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
buffer := make([]byte, 65507)
|
|
||||||
for {
|
|
||||||
n, _, err := targetConn.ReadFromUDP(buffer)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error reading from target: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = conn.WriteTo(buffer[:n], remoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error writing to client: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = targetConn.Write(buffer[:n])
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error writing to target: %v", err)
|
|
||||||
targetConn.Close()
|
|
||||||
clientsMutex.Lock()
|
|
||||||
delete(clientConns, clientKey)
|
|
||||||
clientsMutex.Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
606
wg/wg.go
606
wg/wg.go
@@ -1,606 +0,0 @@
|
|||||||
package wg
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
|
||||||
"github.com/fosrl/newt/websocket"
|
|
||||||
"github.com/vishvananda/netlink"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
interfaceName string
|
|
||||||
listenAddr string
|
|
||||||
mtuInt int
|
|
||||||
lastReadings = make(map[string]PeerReading)
|
|
||||||
mu sync.Mutex
|
|
||||||
)
|
|
||||||
|
|
||||||
type WgConfig struct {
|
|
||||||
PrivateKey string `json:"privateKey"`
|
|
||||||
ListenPort int `json:"listenPort"`
|
|
||||||
IpAddress string `json:"ipAddress"`
|
|
||||||
Peers []Peer `json:"peers"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Peer struct {
|
|
||||||
PublicKey string `json:"publicKey"`
|
|
||||||
AllowedIPs []string `json:"allowedIps"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type PeerBandwidth struct {
|
|
||||||
PublicKey string `json:"publicKey"`
|
|
||||||
BytesIn float64 `json:"bytesIn"`
|
|
||||||
BytesOut float64 `json:"bytesOut"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type PeerReading struct {
|
|
||||||
BytesReceived int64
|
|
||||||
BytesTransmitted int64
|
|
||||||
LastChecked time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
wgClient *wgctrl.Client
|
|
||||||
)
|
|
||||||
|
|
||||||
type WireGuardService struct {
|
|
||||||
interfaceName string
|
|
||||||
mtu int
|
|
||||||
client *websocket.Client
|
|
||||||
wgClient *wgctrl.Client
|
|
||||||
config WgConfig
|
|
||||||
key wgtypes.Key
|
|
||||||
reachableAt string
|
|
||||||
lastReadings map[string]PeerReading
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) {
|
|
||||||
wgClient, err := wgctrl.New()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
key := wgtypes.Key{}
|
|
||||||
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
|
|
||||||
if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
|
|
||||||
// generate a new private key
|
|
||||||
key, err = wgtypes.GeneratePrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to generate private key: %v", err)
|
|
||||||
}
|
|
||||||
// save the key to the file
|
|
||||||
err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to save private key: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
keyData, err := os.ReadFile(generateAndSaveKeyTo)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to read private key: %v", err)
|
|
||||||
}
|
|
||||||
key, err = wgtypes.ParseKey(string(keyData))
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to parse private key: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
service := &WireGuardService{
|
|
||||||
interfaceName: interfaceName,
|
|
||||||
mtu: mtu,
|
|
||||||
client: wsClient,
|
|
||||||
wgClient: wgClient,
|
|
||||||
key: key,
|
|
||||||
reachableAt: reachableAt,
|
|
||||||
lastReadings: make(map[string]PeerReading),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register websocket handlers
|
|
||||||
wsClient.RegisterHandler("wg/config/receive", service.handleConfig)
|
|
||||||
wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer)
|
|
||||||
wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer)
|
|
||||||
|
|
||||||
// Register connect handler to initiate configuration
|
|
||||||
wsClient.OnConnect(service.loadRemoteConfig)
|
|
||||||
|
|
||||||
return service, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) Close() {
|
|
||||||
s.client.Close()
|
|
||||||
wgClient.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) loadRemoteConfig() error {
|
|
||||||
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "endpoint": "%s"}`, s.key.PublicKey().String(), s.reachableAt)))
|
|
||||||
|
|
||||||
go s.periodicBandwidthCheck()
|
|
||||||
|
|
||||||
err := s.client.SendMessage("wg/config/get", body)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to send config request: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
|
||||||
var config WgConfig
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error marshaling data: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(jsonData, &config); err != nil {
|
|
||||||
logger.Info("Error unmarshaling target data: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.config = config
|
|
||||||
|
|
||||||
// Ensure the WireGuard interface and peers are configured
|
|
||||||
if err := s.ensureWireguardInterface(config); err != nil {
|
|
||||||
logger.Error("Failed to ensure WireGuard interface: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.ensureWireguardPeers(config.Peers); err != nil {
|
|
||||||
logger.Error("Failed to ensure WireGuard peers: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
|
||||||
// Check if the WireGuard interface exists
|
|
||||||
_, err := netlink.LinkByName(interfaceName)
|
|
||||||
if err != nil {
|
|
||||||
if _, ok := err.(netlink.LinkNotFoundError); ok {
|
|
||||||
// Interface doesn't exist, so create it
|
|
||||||
err = s.createWireGuardInterface()
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to create WireGuard interface: %v", err)
|
|
||||||
}
|
|
||||||
logger.Info("Created WireGuard interface %s\n", interfaceName)
|
|
||||||
} else {
|
|
||||||
logger.Fatal("Error checking for WireGuard interface: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.Info("WireGuard interface %s already exists\n", interfaceName)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assign IP address to the interface
|
|
||||||
err = s.assignIPAddress(wgconfig.IpAddress)
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("Failed to assign IP address: %v", err)
|
|
||||||
}
|
|
||||||
logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, interfaceName)
|
|
||||||
|
|
||||||
// Check if the interface already exists
|
|
||||||
_, err = wgClient.Device(interfaceName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("interface %s does not exist", interfaceName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the private key
|
|
||||||
key, err := wgtypes.ParseKey(wgconfig.PrivateKey)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse private key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new WireGuard configuration
|
|
||||||
config := wgtypes.Config{
|
|
||||||
PrivateKey: &key,
|
|
||||||
ListenPort: new(int),
|
|
||||||
}
|
|
||||||
*config.ListenPort = wgconfig.ListenPort
|
|
||||||
|
|
||||||
// Create and configure the WireGuard interface
|
|
||||||
err = wgClient.ConfigureDevice(interfaceName, config)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to configure WireGuard device: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// bring up the interface
|
|
||||||
link, err := netlink.LinkByName(interfaceName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get interface: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := netlink.LinkSetMTU(link, mtuInt); err != nil {
|
|
||||||
return fmt.Errorf("failed to set MTU: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := netlink.LinkSetUp(link); err != nil {
|
|
||||||
return fmt.Errorf("failed to bring up interface: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// if err := s.ensureMSSClamping(); err != nil {
|
|
||||||
// logger.Warn("Failed to ensure MSS clamping: %v", err)
|
|
||||||
// }
|
|
||||||
|
|
||||||
logger.Info("WireGuard interface %s created and configured", interfaceName)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) createWireGuardInterface() error {
|
|
||||||
wgLink := &netlink.GenericLink{
|
|
||||||
LinkAttrs: netlink.LinkAttrs{Name: interfaceName},
|
|
||||||
LinkType: "wireguard",
|
|
||||||
}
|
|
||||||
return netlink.LinkAdd(wgLink)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) assignIPAddress(ipAddress string) error {
|
|
||||||
link, err := netlink.LinkByName(interfaceName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get interface: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, err := netlink.ParseAddr(ipAddress)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse IP address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return netlink.AddrAdd(link, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
|
||||||
// get the current peers
|
|
||||||
device, err := wgClient.Device(interfaceName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get device: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the peer public keys
|
|
||||||
var currentPeers []string
|
|
||||||
for _, peer := range device.Peers {
|
|
||||||
currentPeers = append(currentPeers, peer.PublicKey.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove any peers that are not in the config
|
|
||||||
for _, peer := range currentPeers {
|
|
||||||
found := false
|
|
||||||
for _, configPeer := range peers {
|
|
||||||
if peer == configPeer.PublicKey {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
err := s.removePeer(peer)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to remove peer: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// add any peers that are in the config but not in the current peers
|
|
||||||
for _, configPeer := range peers {
|
|
||||||
found := false
|
|
||||||
for _, peer := range currentPeers {
|
|
||||||
if configPeer.PublicKey == peer {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
err := s.addPeer(configPeer)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to add peer: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// func (s *WireGuardService) ensureMSSClamping() error {
|
|
||||||
// // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
|
|
||||||
// mssValue := mtuInt - 40
|
|
||||||
|
|
||||||
// // Rules to be managed - just the chains, we'll construct the full command separately
|
|
||||||
// chains := []string{"INPUT", "OUTPUT", "FORWARD"}
|
|
||||||
|
|
||||||
// // First, try to delete any existing rules
|
|
||||||
// for _, chain := range chains {
|
|
||||||
// deleteCmd := exec.Command("/usr/sbin/iptables",
|
|
||||||
// "-t", "mangle",
|
|
||||||
// "-D", chain,
|
|
||||||
// "-p", "tcp",
|
|
||||||
// "--tcp-flags", "SYN,RST", "SYN",
|
|
||||||
// "-j", "TCPMSS",
|
|
||||||
// "--set-mss", fmt.Sprintf("%d", mssValue))
|
|
||||||
|
|
||||||
// logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain)
|
|
||||||
|
|
||||||
// // Try deletion multiple times to handle multiple existing rules
|
|
||||||
// for i := 0; i < 3; i++ {
|
|
||||||
// out, err := deleteCmd.CombinedOutput()
|
|
||||||
// if err != nil {
|
|
||||||
// // Convert exit status 1 to string for better logging
|
|
||||||
// if exitErr, ok := err.(*exec.ExitError); ok {
|
|
||||||
// logger.Debug("Deletion stopped for chain %s: %v (output: %s)",
|
|
||||||
// chain, exitErr.String(), string(out))
|
|
||||||
// }
|
|
||||||
// break // No more rules to delete
|
|
||||||
// }
|
|
||||||
// logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Then add the new rules
|
|
||||||
// var errors []error
|
|
||||||
// for _, chain := range chains {
|
|
||||||
// addCmd := exec.Command("/usr/sbin/iptables",
|
|
||||||
// "-t", "mangle",
|
|
||||||
// "-A", chain,
|
|
||||||
// "-p", "tcp",
|
|
||||||
// "--tcp-flags", "SYN,RST", "SYN",
|
|
||||||
// "-j", "TCPMSS",
|
|
||||||
// "--set-mss", fmt.Sprintf("%d", mssValue))
|
|
||||||
|
|
||||||
// logger.Info("Adding MSS clamping rule for chain %s", chain)
|
|
||||||
|
|
||||||
// if out, err := addCmd.CombinedOutput(); err != nil {
|
|
||||||
// errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
|
|
||||||
// chain, err, string(out))
|
|
||||||
// logger.Error(errMsg)
|
|
||||||
// errors = append(errors, fmt.Errorf(errMsg))
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // Verify the rule was added
|
|
||||||
// checkCmd := exec.Command("/usr/sbin/iptables",
|
|
||||||
// "-t", "mangle",
|
|
||||||
// "-C", chain,
|
|
||||||
// "-p", "tcp",
|
|
||||||
// "--tcp-flags", "SYN,RST", "SYN",
|
|
||||||
// "-j", "TCPMSS",
|
|
||||||
// "--set-mss", fmt.Sprintf("%d", mssValue))
|
|
||||||
|
|
||||||
// if out, err := checkCmd.CombinedOutput(); err != nil {
|
|
||||||
// errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
|
|
||||||
// chain, err, string(out))
|
|
||||||
// logger.Error(errMsg)
|
|
||||||
// errors = append(errors, fmt.Errorf(errMsg))
|
|
||||||
// continue
|
|
||||||
// }
|
|
||||||
|
|
||||||
// logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // If we encountered any errors, return them combined
|
|
||||||
// if len(errors) > 0 {
|
|
||||||
// var errMsgs []string
|
|
||||||
// for _, err := range errors {
|
|
||||||
// errMsgs = append(errMsgs, err.Error())
|
|
||||||
// }
|
|
||||||
// return fmt.Errorf("MSS clamping setup encountered errors:\n%s",
|
|
||||||
// strings.Join(errMsgs, "\n"))
|
|
||||||
// }
|
|
||||||
|
|
||||||
// return nil
|
|
||||||
// }
|
|
||||||
|
|
||||||
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
|
|
||||||
var peer Peer
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error marshaling data: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(jsonData, &peer); err != nil {
|
|
||||||
logger.Info("Error unmarshaling target data: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.addPeer(peer)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) addPeer(peer Peer) error {
|
|
||||||
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse public key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parse allowed IPs into array of net.IPNet
|
|
||||||
var allowedIPs []net.IPNet
|
|
||||||
for _, ipStr := range peer.AllowedIPs {
|
|
||||||
_, ipNet, err := net.ParseCIDR(ipStr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse allowed IP: %v", err)
|
|
||||||
}
|
|
||||||
allowedIPs = append(allowedIPs, *ipNet)
|
|
||||||
}
|
|
||||||
|
|
||||||
peerConfig := wgtypes.PeerConfig{
|
|
||||||
PublicKey: pubKey,
|
|
||||||
AllowedIPs: allowedIPs,
|
|
||||||
}
|
|
||||||
|
|
||||||
config := wgtypes.Config{
|
|
||||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
|
|
||||||
return fmt.Errorf("failed to add peer: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Peer %s added successfully", peer.PublicKey)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) {
|
|
||||||
// parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" }
|
|
||||||
type RemoveRequest struct {
|
|
||||||
PublicKey string `json:"publicKey"`
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error marshaling data: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var request RemoveRequest
|
|
||||||
if err := json.Unmarshal(jsonData, &request); err != nil {
|
|
||||||
logger.Info("Error unmarshaling data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.removePeer(request.PublicKey); err != nil {
|
|
||||||
logger.Info("Error removing peer: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) removePeer(publicKey string) error {
|
|
||||||
pubKey, err := wgtypes.ParseKey(publicKey)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse public key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
peerConfig := wgtypes.PeerConfig{
|
|
||||||
PublicKey: pubKey,
|
|
||||||
Remove: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
config := wgtypes.Config{
|
|
||||||
Peers: []wgtypes.PeerConfig{peerConfig},
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
|
|
||||||
return fmt.Errorf("failed to remove peer: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Peer %s removed successfully", publicKey)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) periodicBandwidthCheck() {
|
|
||||||
ticker := time.NewTicker(10 * time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for range ticker.C {
|
|
||||||
if err := s.reportPeerBandwidth(); err != nil {
|
|
||||||
logger.Info("Failed to report peer bandwidth: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
|
||||||
device, err := wgClient.Device(interfaceName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get device: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
peerBandwidths := []PeerBandwidth{}
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
|
|
||||||
for _, peer := range device.Peers {
|
|
||||||
publicKey := peer.PublicKey.String()
|
|
||||||
currentReading := PeerReading{
|
|
||||||
BytesReceived: peer.ReceiveBytes,
|
|
||||||
BytesTransmitted: peer.TransmitBytes,
|
|
||||||
LastChecked: now,
|
|
||||||
}
|
|
||||||
|
|
||||||
var bytesInDiff, bytesOutDiff float64
|
|
||||||
lastReading, exists := lastReadings[publicKey]
|
|
||||||
|
|
||||||
if exists {
|
|
||||||
timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds()
|
|
||||||
if timeDiff > 0 {
|
|
||||||
// Calculate bytes transferred since last reading
|
|
||||||
bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived)
|
|
||||||
bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted)
|
|
||||||
|
|
||||||
// Handle counter wraparound (if the counter resets or overflows)
|
|
||||||
if bytesInDiff < 0 {
|
|
||||||
bytesInDiff = float64(currentReading.BytesReceived)
|
|
||||||
}
|
|
||||||
if bytesOutDiff < 0 {
|
|
||||||
bytesOutDiff = float64(currentReading.BytesTransmitted)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to MB
|
|
||||||
bytesInMB := bytesInDiff / (1024 * 1024)
|
|
||||||
bytesOutMB := bytesOutDiff / (1024 * 1024)
|
|
||||||
|
|
||||||
peerBandwidths = append(peerBandwidths, PeerBandwidth{
|
|
||||||
PublicKey: publicKey,
|
|
||||||
BytesIn: bytesInMB,
|
|
||||||
BytesOut: bytesOutMB,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
// If readings are too close together or time hasn't passed, report 0
|
|
||||||
peerBandwidths = append(peerBandwidths, PeerBandwidth{
|
|
||||||
PublicKey: publicKey,
|
|
||||||
BytesIn: 0,
|
|
||||||
BytesOut: 0,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// For first reading of a peer, report 0 to establish baseline
|
|
||||||
peerBandwidths = append(peerBandwidths, PeerBandwidth{
|
|
||||||
PublicKey: publicKey,
|
|
||||||
BytesIn: 0,
|
|
||||||
BytesOut: 0,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the last reading
|
|
||||||
lastReadings[publicKey] = currentReading
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up old peers
|
|
||||||
for publicKey := range lastReadings {
|
|
||||||
found := false
|
|
||||||
for _, peer := range device.Peers {
|
|
||||||
if peer.PublicKey.String() == publicKey {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
delete(lastReadings, publicKey)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return peerBandwidths, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *WireGuardService) reportPeerBandwidth() error {
|
|
||||||
bandwidths, err := s.calculatePeerBandwidth()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(bandwidths)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal bandwidth data: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.client.SendMessage("wg/bandwidth", jsonData)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to send bandwidth data: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user