Replace netstack and remove proxy

This commit is contained in:
Owen
2025-02-20 21:24:39 -05:00
parent 66edae4288
commit 6107d20e26
3 changed files with 100 additions and 1180 deletions

322
main.go
View File

@@ -1,7 +1,6 @@
package main
import (
"bytes"
"encoding/base64"
"encoding/hex"
"encoding/json"
@@ -9,7 +8,6 @@ import (
"fmt"
"math/rand"
"net"
"net/netip"
"os"
"os/signal"
"strconv"
@@ -18,16 +16,15 @@ import (
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/websocket"
"github.com/fosrl/newt/wg"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -62,59 +59,93 @@ func fixKey(key string) string {
return hex.EncodeToString(decoded)
}
func ping(tnet *netstack.Net, dst string) error {
logger.Info("Pinging %s", dst)
socket, err := tnet.Dial("ping4", dst)
const (
ENV_WG_TUN_FD = "WG_TUN_FD"
ENV_WG_UAPI_FD = "WG_UAPI_FD"
ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
)
func ping(dev *device.Device, dst string) error {
logger.Info("Pinging %s over WireGuard tunnel", dst)
// Create a raw socket for ICMP
conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
if err != nil {
return fmt.Errorf("failed to create ICMP socket: %w", err)
}
defer socket.Close()
defer conn.Close()
requestPing := icmp.Echo{
Seq: rand.Intn(1 << 16),
Data: []byte("gopher burrow"),
// Parse destination IP
dstIP := net.ParseIP(dst)
if dstIP == nil {
return fmt.Errorf("invalid destination IP: %s", dst)
}
icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
// Create ICMP message
requestPing := icmp.Echo{
ID: os.Getpid() & 0xffff,
Seq: rand.Intn(1 << 16),
Data: []byte("wireguard ping"),
}
msg := icmp.Message{
Type: ipv4.ICMPTypeEcho,
Code: 0,
Body: &requestPing,
}
// Marshal the message
icmpBytes, err := msg.Marshal(nil)
if err != nil {
return fmt.Errorf("failed to marshal ICMP message: %w", err)
}
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
// Set read deadline
if err := conn.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
return fmt.Errorf("failed to set read deadline: %w", err)
}
// Send the ping
start := time.Now()
_, err = socket.Write(icmpBytes)
_, err = conn.WriteTo(icmpBytes, &net.IPAddr{IP: dstIP})
if err != nil {
return fmt.Errorf("failed to write ICMP packet: %w", err)
}
n, err := socket.Read(icmpBytes[:])
// Wait for reply
reply := make([]byte, 1500)
n, peer, err := conn.ReadFrom(reply)
if err != nil {
return fmt.Errorf("failed to read ICMP packet: %w", err)
}
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
// Parse reply
replyMsg, err := icmp.ParseMessage(1, reply[:n])
if err != nil {
return fmt.Errorf("failed to parse ICMP packet: %w", err)
return fmt.Errorf("failed to parse ICMP reply: %w", err)
}
replyPing, ok := replyPacket.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body)
// Verify reply
switch replyMsg.Type {
case ipv4.ICMPTypeEchoReply:
replyEcho, ok := replyMsg.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyMsg.Body)
}
if replyEcho.ID != requestPing.ID || replyEcho.Seq != requestPing.Seq {
return fmt.Errorf("invalid echo reply: got id=%d seq=%d, want id=%d seq=%d",
replyEcho.ID, replyEcho.Seq, requestPing.ID, requestPing.Seq)
}
default:
return fmt.Errorf("unexpected ICMP message type: %+v", replyMsg)
}
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
return fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q",
replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data)
}
logger.Info("Ping latency: %v", time.Since(start))
duration := time.Since(start)
logger.Info("Ping reply from %v: time=%v", peer, duration)
return nil
}
func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) {
func startPingCheck(dev *device.Device, serverIP string, stopChan chan struct{}) {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
@@ -122,10 +153,10 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
for {
select {
case <-ticker.C:
err := ping(tnet, serverIP)
err := ping(dev, serverIP)
if err != nil {
logger.Warn("Periodic ping failed: %v", err)
logger.Warn("HINT: Do you have UDP port 51280 (or the port in config.yml) open on your Pangolin server?")
logger.Warn("HINT: Check if the WireGuard tunnel is up and the server is reachable")
}
case <-stopChan:
logger.Info("Stopping ping check")
@@ -135,7 +166,7 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
}()
}
func pingWithRetry(tnet *netstack.Net, dst string) error {
func pingWithRetry(dev *device.Device, dst string) error {
const (
maxAttempts = 5
retryDelay = 2 * time.Second
@@ -145,7 +176,7 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
for attempt := 1; attempt <= maxAttempts; attempt++ {
logger.Info("Ping attempt %d of %d", attempt, maxAttempts)
if err := ping(tnet, dst); err != nil {
if err := ping(dev, dst); err != nil {
lastErr = err
logger.Warn("Ping attempt %d failed: %v", attempt, err)
@@ -161,7 +192,6 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
return nil
}
// This shouldn't be reached due to the return in the loop, but added for completeness
return fmt.Errorf("unexpected error: all ping attempts failed")
}
@@ -335,29 +365,13 @@ func main() {
logger.Fatal("Failed to create client: %v", err)
}
// Create WireGuard service
wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
if err != nil {
logger.Fatal("Failed to create WireGuard service: %v", err)
}
defer wgService.Close()
// Create TUN device and network stack
var tun tun.Device
var tnet *netstack.Net
var dev *device.Device
var pm *proxy.ProxyManager
var connected bool
var wgData WgData
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
client.RegisterHandler("client/terminate", func(msg websocket.WSMessage) {
logger.Info("Received terminate message")
if pm != nil {
pm.Stop()
}
if dev != nil {
dev.Close()
}
client.Close()
})
@@ -365,13 +379,12 @@ func main() {
defer close(pingStopChan)
// Register handlers for different message types
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
client.RegisterHandler("client/wg/connect", func(msg websocket.WSMessage) {
logger.Info("Received registration message")
if connected {
logger.Info("Already connected! But I will send a ping anyway...")
// ping(tnet, wgData.ServerIP)
err = pingWithRetry(tnet, wgData.ServerIP)
err := pingWithRetry(dev, wgData.ServerIP)
if err != nil {
// Handle complete failure after all retries
logger.Warn("Failed to ping %s: %v", wgData.ServerIP, err)
@@ -391,17 +404,39 @@ func main() {
return
}
logger.Info("Received: %+v", msg)
tun, tnet, err = netstack.CreateNetTUN(
[]netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
[]netip.Addr{netip.MustParseAddr(dns)},
mtuInt)
if err != nil {
logger.Error("Failed to create TUN device: %v", err)
}
// logger.Info("Received: %+v", msg)
// tun, tnet, err = netstack.CreateNetTUN(
// []netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
// []netip.Addr{netip.MustParseAddr(dns)},
// mtuInt)
// if err != nil {
// logger.Error("Failed to create TUN device: %v", err)
// }
tdev, err := func() (tun.Device, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
if tunFdStr == "" {
return tun.CreateTUN(interfaceName, mtuInt)
}
// construct tun device from supplied fd
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
if err != nil {
return nil, err
}
err = unix.SetNonblock(int(fd), true)
if err != nil {
return nil, err
}
file := os.NewFile(uintptr(fd), "")
return tun.CreateTUNFromFile(file, mtuInt)
}()
// Create WireGuard device
dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(
dev = device.NewDevice(tdev, conn.NewDefaultBind(), device.NewLogger(
mapToWireGuardLogLevel(loggerLevel),
"wireguard: ",
))
@@ -433,7 +468,7 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
logger.Info("WireGuard device created. Lets ping the server now...")
// Ping to bring the tunnel up on the server side quickly
// ping(tnet, wgData.ServerIP)
err = pingWithRetry(tnet, wgData.ServerIP)
err = pingWithRetry(dev, wgData.ServerIP)
if err != nil {
// Handle complete failure after all retries
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
@@ -441,114 +476,16 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
if !connected {
logger.Info("Starting ping check")
startPingCheck(tnet, wgData.ServerIP, pingStopChan)
startPingCheck(dev, wgData.ServerIP, pingStopChan)
}
// Create proxy manager
pm = proxy.NewProxyManager(tnet)
connected = true
// add the targets if there are any
if len(wgData.Targets.TCP) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
}
if len(wgData.Targets.UDP) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
}
err = pm.Start()
if err != nil {
logger.Error("Failed to start proxy manager: %v", err)
}
})
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
}
})
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
}
})
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData)
}
})
client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData)
}
})
client.OnConnect(func() error {
publicKey := privateKey.PublicKey()
logger.Debug("Public key: %s", publicKey)
err := client.SendMessage("newt/wg/register", map[string]interface{}{
err := client.SendMessage("client/wg/register", map[string]interface{}{
"publicKey": fmt.Sprintf("%s", publicKey),
})
if err != nil {
@@ -574,62 +511,3 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
// Cleanup
dev.Close()
}
func parseTargetData(data interface{}) (TargetData, error) {
var targetData TargetData
jsonData, err := json.Marshal(data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return targetData, err
}
if err := json.Unmarshal(jsonData, &targetData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return targetData, err
}
return targetData, nil
}
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
if len(parts) != 3 {
logger.Info("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil {
logger.Info("Invalid port: %s", parts[0])
continue
}
if action == "add" {
target := parts[1] + ":" + parts[2]
// Only remove the specific target if it exists
err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil {
// Ignore "target not found" errors as this is expected for new targets
if !strings.Contains(err.Error(), "target not found") {
logger.Error("Failed to remove existing target: %v", err)
}
}
// Add the new target
pm.AddTarget(proto, tunnelIP, port, target)
} else if action == "remove" {
logger.Info("Removing target with port %d", port)
err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil {
logger.Error("Failed to remove target: %v", err)
return err
}
}
}
return nil
}

View File

@@ -1,352 +0,0 @@
package proxy
import (
"fmt"
"io"
"net"
"strings"
"sync"
"time"
"github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/tun/netstack"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
)
// Target represents a proxy target with its address and port
type Target struct {
Address string
Port int
}
// ProxyManager handles the creation and management of proxy connections
type ProxyManager struct {
tnet *netstack.Net
tcpTargets map[string]map[int]string // map[listenIP]map[port]targetAddress
udpTargets map[string]map[int]string
listeners []*gonet.TCPListener
udpConns []*gonet.UDPConn
running bool
mutex sync.RWMutex
}
// NewProxyManager creates a new proxy manager instance
func NewProxyManager(tnet *netstack.Net) *ProxyManager {
return &ProxyManager{
tnet: tnet,
tcpTargets: make(map[string]map[int]string),
udpTargets: make(map[string]map[int]string),
listeners: make([]*gonet.TCPListener, 0),
udpConns: make([]*gonet.UDPConn, 0),
}
}
// AddTarget adds as new target for proxying
func (pm *ProxyManager) AddTarget(proto, listenIP string, port int, targetAddr string) error {
pm.mutex.Lock()
defer pm.mutex.Unlock()
switch proto {
case "tcp":
if pm.tcpTargets[listenIP] == nil {
pm.tcpTargets[listenIP] = make(map[int]string)
}
pm.tcpTargets[listenIP][port] = targetAddr
case "udp":
if pm.udpTargets[listenIP] == nil {
pm.udpTargets[listenIP] = make(map[int]string)
}
pm.udpTargets[listenIP][port] = targetAddr
default:
return fmt.Errorf("unsupported protocol: %s", proto)
}
if pm.running {
return pm.startTarget(proto, listenIP, port, targetAddr)
} else {
logger.Debug("Not adding target because not running")
}
return nil
}
func (pm *ProxyManager) RemoveTarget(proto, listenIP string, port int) error {
pm.mutex.Lock()
defer pm.mutex.Unlock()
switch proto {
case "tcp":
if targets, ok := pm.tcpTargets[listenIP]; ok {
delete(targets, port)
// Remove and close the corresponding TCP listener
for i, listener := range pm.listeners {
if addr, ok := listener.Addr().(*net.TCPAddr); ok && addr.Port == port {
listener.Close()
time.Sleep(50 * time.Millisecond)
// Remove from slice
pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...)
break
}
}
} else {
return fmt.Errorf("target not found: %s:%d", listenIP, port)
}
case "udp":
if targets, ok := pm.udpTargets[listenIP]; ok {
delete(targets, port)
// Remove and close the corresponding UDP connection
for i, conn := range pm.udpConns {
if addr, ok := conn.LocalAddr().(*net.UDPAddr); ok && addr.Port == port {
conn.Close()
time.Sleep(50 * time.Millisecond)
// Remove from slice
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
break
}
}
} else {
return fmt.Errorf("target not found: %s:%d", listenIP, port)
}
default:
return fmt.Errorf("unsupported protocol: %s", proto)
}
return nil
}
// Start begins listening for all configured proxy targets
func (pm *ProxyManager) Start() error {
pm.mutex.Lock()
defer pm.mutex.Unlock()
if pm.running {
return nil
}
// Start TCP targets
for listenIP, targets := range pm.tcpTargets {
for port, targetAddr := range targets {
if err := pm.startTarget("tcp", listenIP, port, targetAddr); err != nil {
return fmt.Errorf("failed to start TCP target: %v", err)
}
}
}
// Start UDP targets
for listenIP, targets := range pm.udpTargets {
for port, targetAddr := range targets {
if err := pm.startTarget("udp", listenIP, port, targetAddr); err != nil {
return fmt.Errorf("failed to start UDP target: %v", err)
}
}
}
pm.running = true
return nil
}
func (pm *ProxyManager) Stop() error {
pm.mutex.Lock()
defer pm.mutex.Unlock()
if !pm.running {
return nil
}
// Set running to false first to signal handlers to stop
pm.running = false
// Close TCP listeners
for i := len(pm.listeners) - 1; i >= 0; i-- {
listener := pm.listeners[i]
if err := listener.Close(); err != nil {
logger.Error("Error closing TCP listener: %v", err)
}
// Remove from slice
pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...)
}
// Close UDP connections
for i := len(pm.udpConns) - 1; i >= 0; i-- {
conn := pm.udpConns[i]
if err := conn.Close(); err != nil {
logger.Error("Error closing UDP connection: %v", err)
}
// Remove from slice
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
}
// Clear the target maps
for k := range pm.tcpTargets {
delete(pm.tcpTargets, k)
}
for k := range pm.udpTargets {
delete(pm.udpTargets, k)
}
// Give active connections a chance to close gracefully
time.Sleep(100 * time.Millisecond)
return nil
}
func (pm *ProxyManager) startTarget(proto, listenIP string, port int, targetAddr string) error {
switch proto {
case "tcp":
listener, err := pm.tnet.ListenTCP(&net.TCPAddr{Port: port})
if err != nil {
return fmt.Errorf("failed to create TCP listener: %v", err)
}
pm.listeners = append(pm.listeners, listener)
go pm.handleTCPProxy(listener, targetAddr)
case "udp":
addr := &net.UDPAddr{Port: port}
conn, err := pm.tnet.ListenUDP(addr)
if err != nil {
return fmt.Errorf("failed to create UDP listener: %v", err)
}
pm.udpConns = append(pm.udpConns, conn)
go pm.handleUDPProxy(conn, targetAddr)
default:
return fmt.Errorf("unsupported protocol: %s", proto)
}
logger.Info("Started %s proxy from %s:%d to %s", proto, listenIP, port, targetAddr)
return nil
}
func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string) {
for {
conn, err := listener.Accept()
if err != nil {
// Check if we're shutting down or the listener was closed
if !pm.running {
return
}
// Check for specific network errors that indicate the listener is closed
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
return
}
logger.Error("Error accepting TCP connection: %v", err)
// Don't hammer the CPU if we hit a temporary error
time.Sleep(100 * time.Millisecond)
continue
}
go func() {
target, err := net.Dial("tcp", targetAddr)
if err != nil {
logger.Error("Error connecting to target: %v", err)
conn.Close()
return
}
// Create a WaitGroup to ensure both copy operations complete
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
io.Copy(target, conn)
target.Close()
}()
go func() {
defer wg.Done()
io.Copy(conn, target)
conn.Close()
}()
// Wait for both copies to complete
wg.Wait()
}()
}
}
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
buffer := make([]byte, 65507) // Max UDP packet size
clientConns := make(map[string]*net.UDPConn)
var clientsMutex sync.RWMutex
for {
n, remoteAddr, err := conn.ReadFrom(buffer)
if err != nil {
if !pm.running {
return
}
// Check for connection closed conditions
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
logger.Info("UDP connection closed, stopping proxy handler")
// Clean up existing client connections
clientsMutex.Lock()
for _, targetConn := range clientConns {
targetConn.Close()
}
clientConns = nil
clientsMutex.Unlock()
return
}
logger.Error("Error reading UDP packet: %v", err)
continue
}
clientKey := remoteAddr.String()
clientsMutex.RLock()
targetConn, exists := clientConns[clientKey]
clientsMutex.RUnlock()
if !exists {
targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
if err != nil {
logger.Error("Error resolving target address: %v", err)
continue
}
targetConn, err = net.DialUDP("udp", nil, targetUDPAddr)
if err != nil {
logger.Error("Error connecting to target: %v", err)
continue
}
clientsMutex.Lock()
clientConns[clientKey] = targetConn
clientsMutex.Unlock()
go func() {
buffer := make([]byte, 65507)
for {
n, _, err := targetConn.ReadFromUDP(buffer)
if err != nil {
logger.Error("Error reading from target: %v", err)
return
}
_, err = conn.WriteTo(buffer[:n], remoteAddr)
if err != nil {
logger.Error("Error writing to client: %v", err)
return
}
}
}()
}
_, err = targetConn.Write(buffer[:n])
if err != nil {
logger.Error("Error writing to target: %v", err)
targetConn.Close()
clientsMutex.Lock()
delete(clientConns, clientKey)
clientsMutex.Unlock()
}
}
}

606
wg/wg.go
View File

@@ -1,606 +0,0 @@
package wg
import (
"bytes"
"encoding/json"
"fmt"
"net"
"os"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/websocket"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var (
interfaceName string
listenAddr string
mtuInt int
lastReadings = make(map[string]PeerReading)
mu sync.Mutex
)
type WgConfig struct {
PrivateKey string `json:"privateKey"`
ListenPort int `json:"listenPort"`
IpAddress string `json:"ipAddress"`
Peers []Peer `json:"peers"`
}
type Peer struct {
PublicKey string `json:"publicKey"`
AllowedIPs []string `json:"allowedIps"`
}
type PeerBandwidth struct {
PublicKey string `json:"publicKey"`
BytesIn float64 `json:"bytesIn"`
BytesOut float64 `json:"bytesOut"`
}
type PeerReading struct {
BytesReceived int64
BytesTransmitted int64
LastChecked time.Time
}
var (
wgClient *wgctrl.Client
)
type WireGuardService struct {
interfaceName string
mtu int
client *websocket.Client
wgClient *wgctrl.Client
config WgConfig
key wgtypes.Key
reachableAt string
lastReadings map[string]PeerReading
mu sync.Mutex
}
func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) {
wgClient, err := wgctrl.New()
if err != nil {
return nil, fmt.Errorf("failed to create WireGuard client: %v", err)
}
key := wgtypes.Key{}
// if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file
if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) {
// generate a new private key
key, err = wgtypes.GeneratePrivateKey()
if err != nil {
logger.Fatal("Failed to generate private key: %v", err)
}
// save the key to the file
err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644)
if err != nil {
logger.Fatal("Failed to save private key: %v", err)
}
} else {
keyData, err := os.ReadFile(generateAndSaveKeyTo)
if err != nil {
logger.Fatal("Failed to read private key: %v", err)
}
key, err = wgtypes.ParseKey(string(keyData))
if err != nil {
logger.Fatal("Failed to parse private key: %v", err)
}
}
service := &WireGuardService{
interfaceName: interfaceName,
mtu: mtu,
client: wsClient,
wgClient: wgClient,
key: key,
reachableAt: reachableAt,
lastReadings: make(map[string]PeerReading),
}
// Register websocket handlers
wsClient.RegisterHandler("wg/config/receive", service.handleConfig)
wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer)
wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer)
// Register connect handler to initiate configuration
wsClient.OnConnect(service.loadRemoteConfig)
return service, nil
}
func (s *WireGuardService) Close() {
s.client.Close()
wgClient.Close()
}
func (s *WireGuardService) loadRemoteConfig() error {
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "endpoint": "%s"}`, s.key.PublicKey().String(), s.reachableAt)))
go s.periodicBandwidthCheck()
err := s.client.SendMessage("wg/config/get", body)
if err != nil {
return fmt.Errorf("failed to send config request: %v", err)
}
return nil
}
func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
var config WgConfig
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
}
if err := json.Unmarshal(jsonData, &config); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
}
s.config = config
// Ensure the WireGuard interface and peers are configured
if err := s.ensureWireguardInterface(config); err != nil {
logger.Error("Failed to ensure WireGuard interface: %v", err)
}
if err := s.ensureWireguardPeers(config.Peers); err != nil {
logger.Error("Failed to ensure WireGuard peers: %v", err)
}
}
func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// Check if the WireGuard interface exists
_, err := netlink.LinkByName(interfaceName)
if err != nil {
if _, ok := err.(netlink.LinkNotFoundError); ok {
// Interface doesn't exist, so create it
err = s.createWireGuardInterface()
if err != nil {
logger.Fatal("Failed to create WireGuard interface: %v", err)
}
logger.Info("Created WireGuard interface %s\n", interfaceName)
} else {
logger.Fatal("Error checking for WireGuard interface: %v", err)
}
} else {
logger.Info("WireGuard interface %s already exists\n", interfaceName)
return nil
}
// Assign IP address to the interface
err = s.assignIPAddress(wgconfig.IpAddress)
if err != nil {
logger.Fatal("Failed to assign IP address: %v", err)
}
logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, interfaceName)
// Check if the interface already exists
_, err = wgClient.Device(interfaceName)
if err != nil {
return fmt.Errorf("interface %s does not exist", interfaceName)
}
// Parse the private key
key, err := wgtypes.ParseKey(wgconfig.PrivateKey)
if err != nil {
return fmt.Errorf("failed to parse private key: %v", err)
}
// Create a new WireGuard configuration
config := wgtypes.Config{
PrivateKey: &key,
ListenPort: new(int),
}
*config.ListenPort = wgconfig.ListenPort
// Create and configure the WireGuard interface
err = wgClient.ConfigureDevice(interfaceName, config)
if err != nil {
return fmt.Errorf("failed to configure WireGuard device: %v", err)
}
// bring up the interface
link, err := netlink.LinkByName(interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface: %v", err)
}
if err := netlink.LinkSetMTU(link, mtuInt); err != nil {
return fmt.Errorf("failed to set MTU: %v", err)
}
if err := netlink.LinkSetUp(link); err != nil {
return fmt.Errorf("failed to bring up interface: %v", err)
}
// if err := s.ensureMSSClamping(); err != nil {
// logger.Warn("Failed to ensure MSS clamping: %v", err)
// }
logger.Info("WireGuard interface %s created and configured", interfaceName)
return nil
}
func (s *WireGuardService) createWireGuardInterface() error {
wgLink := &netlink.GenericLink{
LinkAttrs: netlink.LinkAttrs{Name: interfaceName},
LinkType: "wireguard",
}
return netlink.LinkAdd(wgLink)
}
func (s *WireGuardService) assignIPAddress(ipAddress string) error {
link, err := netlink.LinkByName(interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface: %v", err)
}
addr, err := netlink.ParseAddr(ipAddress)
if err != nil {
return fmt.Errorf("failed to parse IP address: %v", err)
}
return netlink.AddrAdd(link, addr)
}
func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
// get the current peers
device, err := wgClient.Device(interfaceName)
if err != nil {
return fmt.Errorf("failed to get device: %v", err)
}
// get the peer public keys
var currentPeers []string
for _, peer := range device.Peers {
currentPeers = append(currentPeers, peer.PublicKey.String())
}
// remove any peers that are not in the config
for _, peer := range currentPeers {
found := false
for _, configPeer := range peers {
if peer == configPeer.PublicKey {
found = true
break
}
}
if !found {
err := s.removePeer(peer)
if err != nil {
return fmt.Errorf("failed to remove peer: %v", err)
}
}
}
// add any peers that are in the config but not in the current peers
for _, configPeer := range peers {
found := false
for _, peer := range currentPeers {
if configPeer.PublicKey == peer {
found = true
break
}
}
if !found {
err := s.addPeer(configPeer)
if err != nil {
return fmt.Errorf("failed to add peer: %v", err)
}
}
}
return nil
}
// func (s *WireGuardService) ensureMSSClamping() error {
// // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20))
// mssValue := mtuInt - 40
// // Rules to be managed - just the chains, we'll construct the full command separately
// chains := []string{"INPUT", "OUTPUT", "FORWARD"}
// // First, try to delete any existing rules
// for _, chain := range chains {
// deleteCmd := exec.Command("/usr/sbin/iptables",
// "-t", "mangle",
// "-D", chain,
// "-p", "tcp",
// "--tcp-flags", "SYN,RST", "SYN",
// "-j", "TCPMSS",
// "--set-mss", fmt.Sprintf("%d", mssValue))
// logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain)
// // Try deletion multiple times to handle multiple existing rules
// for i := 0; i < 3; i++ {
// out, err := deleteCmd.CombinedOutput()
// if err != nil {
// // Convert exit status 1 to string for better logging
// if exitErr, ok := err.(*exec.ExitError); ok {
// logger.Debug("Deletion stopped for chain %s: %v (output: %s)",
// chain, exitErr.String(), string(out))
// }
// break // No more rules to delete
// }
// logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1)
// }
// }
// // Then add the new rules
// var errors []error
// for _, chain := range chains {
// addCmd := exec.Command("/usr/sbin/iptables",
// "-t", "mangle",
// "-A", chain,
// "-p", "tcp",
// "--tcp-flags", "SYN,RST", "SYN",
// "-j", "TCPMSS",
// "--set-mss", fmt.Sprintf("%d", mssValue))
// logger.Info("Adding MSS clamping rule for chain %s", chain)
// if out, err := addCmd.CombinedOutput(); err != nil {
// errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
// chain, err, string(out))
// logger.Error(errMsg)
// errors = append(errors, fmt.Errorf(errMsg))
// continue
// }
// // Verify the rule was added
// checkCmd := exec.Command("/usr/sbin/iptables",
// "-t", "mangle",
// "-C", chain,
// "-p", "tcp",
// "--tcp-flags", "SYN,RST", "SYN",
// "-j", "TCPMSS",
// "--set-mss", fmt.Sprintf("%d", mssValue))
// if out, err := checkCmd.CombinedOutput(); err != nil {
// errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
// chain, err, string(out))
// logger.Error(errMsg)
// errors = append(errors, fmt.Errorf(errMsg))
// continue
// }
// logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain)
// }
// // If we encountered any errors, return them combined
// if len(errors) > 0 {
// var errMsgs []string
// for _, err := range errors {
// errMsgs = append(errMsgs, err.Error())
// }
// return fmt.Errorf("MSS clamping setup encountered errors:\n%s",
// strings.Join(errMsgs, "\n"))
// }
// return nil
// }
func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
var peer Peer
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
}
if err := json.Unmarshal(jsonData, &peer); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
}
err = s.addPeer(peer)
if err != nil {
return
}
}
func (s *WireGuardService) addPeer(peer Peer) error {
pubKey, err := wgtypes.ParseKey(peer.PublicKey)
if err != nil {
return fmt.Errorf("failed to parse public key: %v", err)
}
// parse allowed IPs into array of net.IPNet
var allowedIPs []net.IPNet
for _, ipStr := range peer.AllowedIPs {
_, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
return fmt.Errorf("failed to parse allowed IP: %v", err)
}
allowedIPs = append(allowedIPs, *ipNet)
}
peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey,
AllowedIPs: allowedIPs,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peerConfig},
}
if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
return fmt.Errorf("failed to add peer: %v", err)
}
logger.Info("Peer %s added successfully", peer.PublicKey)
return nil
}
func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) {
// parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" }
type RemoveRequest struct {
PublicKey string `json:"publicKey"`
}
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
}
var request RemoveRequest
if err := json.Unmarshal(jsonData, &request); err != nil {
logger.Info("Error unmarshaling data: %v", err)
return
}
if err := s.removePeer(request.PublicKey); err != nil {
logger.Info("Error removing peer: %v", err)
return
}
}
func (s *WireGuardService) removePeer(publicKey string) error {
pubKey, err := wgtypes.ParseKey(publicKey)
if err != nil {
return fmt.Errorf("failed to parse public key: %v", err)
}
peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peerConfig},
}
if err := wgClient.ConfigureDevice(interfaceName, config); err != nil {
return fmt.Errorf("failed to remove peer: %v", err)
}
logger.Info("Peer %s removed successfully", publicKey)
return nil
}
func (s *WireGuardService) periodicBandwidthCheck() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for range ticker.C {
if err := s.reportPeerBandwidth(); err != nil {
logger.Info("Failed to report peer bandwidth: %v", err)
}
}
}
func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
device, err := wgClient.Device(interfaceName)
if err != nil {
return nil, fmt.Errorf("failed to get device: %v", err)
}
peerBandwidths := []PeerBandwidth{}
now := time.Now()
mu.Lock()
defer mu.Unlock()
for _, peer := range device.Peers {
publicKey := peer.PublicKey.String()
currentReading := PeerReading{
BytesReceived: peer.ReceiveBytes,
BytesTransmitted: peer.TransmitBytes,
LastChecked: now,
}
var bytesInDiff, bytesOutDiff float64
lastReading, exists := lastReadings[publicKey]
if exists {
timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds()
if timeDiff > 0 {
// Calculate bytes transferred since last reading
bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived)
bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted)
// Handle counter wraparound (if the counter resets or overflows)
if bytesInDiff < 0 {
bytesInDiff = float64(currentReading.BytesReceived)
}
if bytesOutDiff < 0 {
bytesOutDiff = float64(currentReading.BytesTransmitted)
}
// Convert to MB
bytesInMB := bytesInDiff / (1024 * 1024)
bytesOutMB := bytesOutDiff / (1024 * 1024)
peerBandwidths = append(peerBandwidths, PeerBandwidth{
PublicKey: publicKey,
BytesIn: bytesInMB,
BytesOut: bytesOutMB,
})
} else {
// If readings are too close together or time hasn't passed, report 0
peerBandwidths = append(peerBandwidths, PeerBandwidth{
PublicKey: publicKey,
BytesIn: 0,
BytesOut: 0,
})
}
} else {
// For first reading of a peer, report 0 to establish baseline
peerBandwidths = append(peerBandwidths, PeerBandwidth{
PublicKey: publicKey,
BytesIn: 0,
BytesOut: 0,
})
}
// Update the last reading
lastReadings[publicKey] = currentReading
}
// Clean up old peers
for publicKey := range lastReadings {
found := false
for _, peer := range device.Peers {
if peer.PublicKey.String() == publicKey {
found = true
break
}
}
if !found {
delete(lastReadings, publicKey)
}
}
return peerBandwidths, nil
}
func (s *WireGuardService) reportPeerBandwidth() error {
bandwidths, err := s.calculatePeerBandwidth()
if err != nil {
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
}
jsonData, err := json.Marshal(bandwidths)
if err != nil {
return fmt.Errorf("failed to marshal bandwidth data: %v", err)
}
err = s.client.SendMessage("wg/bandwidth", jsonData)
if err != nil {
return fmt.Errorf("failed to send bandwidth data: %v", err)
}
return nil
}