Replace netstack and remove proxy

This commit is contained in:
Owen
2025-02-20 21:24:39 -05:00
parent 66edae4288
commit 6107d20e26
3 changed files with 100 additions and 1180 deletions

322
main.go
View File

@@ -1,7 +1,6 @@
package main
import (
"bytes"
"encoding/base64"
"encoding/hex"
"encoding/json"
@@ -9,7 +8,6 @@ import (
"fmt"
"math/rand"
"net"
"net/netip"
"os"
"os/signal"
"strconv"
@@ -18,16 +16,15 @@ import (
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/websocket"
"github.com/fosrl/newt/wg"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -62,59 +59,93 @@ func fixKey(key string) string {
return hex.EncodeToString(decoded)
}
func ping(tnet *netstack.Net, dst string) error {
logger.Info("Pinging %s", dst)
socket, err := tnet.Dial("ping4", dst)
const (
ENV_WG_TUN_FD = "WG_TUN_FD"
ENV_WG_UAPI_FD = "WG_UAPI_FD"
ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
)
func ping(dev *device.Device, dst string) error {
logger.Info("Pinging %s over WireGuard tunnel", dst)
// Create a raw socket for ICMP
conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
if err != nil {
return fmt.Errorf("failed to create ICMP socket: %w", err)
}
defer socket.Close()
defer conn.Close()
requestPing := icmp.Echo{
Seq: rand.Intn(1 << 16),
Data: []byte("gopher burrow"),
// Parse destination IP
dstIP := net.ParseIP(dst)
if dstIP == nil {
return fmt.Errorf("invalid destination IP: %s", dst)
}
icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
// Create ICMP message
requestPing := icmp.Echo{
ID: os.Getpid() & 0xffff,
Seq: rand.Intn(1 << 16),
Data: []byte("wireguard ping"),
}
msg := icmp.Message{
Type: ipv4.ICMPTypeEcho,
Code: 0,
Body: &requestPing,
}
// Marshal the message
icmpBytes, err := msg.Marshal(nil)
if err != nil {
return fmt.Errorf("failed to marshal ICMP message: %w", err)
}
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
// Set read deadline
if err := conn.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
return fmt.Errorf("failed to set read deadline: %w", err)
}
// Send the ping
start := time.Now()
_, err = socket.Write(icmpBytes)
_, err = conn.WriteTo(icmpBytes, &net.IPAddr{IP: dstIP})
if err != nil {
return fmt.Errorf("failed to write ICMP packet: %w", err)
}
n, err := socket.Read(icmpBytes[:])
// Wait for reply
reply := make([]byte, 1500)
n, peer, err := conn.ReadFrom(reply)
if err != nil {
return fmt.Errorf("failed to read ICMP packet: %w", err)
}
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
// Parse reply
replyMsg, err := icmp.ParseMessage(1, reply[:n])
if err != nil {
return fmt.Errorf("failed to parse ICMP packet: %w", err)
return fmt.Errorf("failed to parse ICMP reply: %w", err)
}
replyPing, ok := replyPacket.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body)
// Verify reply
switch replyMsg.Type {
case ipv4.ICMPTypeEchoReply:
replyEcho, ok := replyMsg.Body.(*icmp.Echo)
if !ok {
return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyMsg.Body)
}
if replyEcho.ID != requestPing.ID || replyEcho.Seq != requestPing.Seq {
return fmt.Errorf("invalid echo reply: got id=%d seq=%d, want id=%d seq=%d",
replyEcho.ID, replyEcho.Seq, requestPing.ID, requestPing.Seq)
}
default:
return fmt.Errorf("unexpected ICMP message type: %+v", replyMsg)
}
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
return fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q",
replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data)
}
logger.Info("Ping latency: %v", time.Since(start))
duration := time.Since(start)
logger.Info("Ping reply from %v: time=%v", peer, duration)
return nil
}
func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) {
func startPingCheck(dev *device.Device, serverIP string, stopChan chan struct{}) {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
@@ -122,10 +153,10 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
for {
select {
case <-ticker.C:
err := ping(tnet, serverIP)
err := ping(dev, serverIP)
if err != nil {
logger.Warn("Periodic ping failed: %v", err)
logger.Warn("HINT: Do you have UDP port 51280 (or the port in config.yml) open on your Pangolin server?")
logger.Warn("HINT: Check if the WireGuard tunnel is up and the server is reachable")
}
case <-stopChan:
logger.Info("Stopping ping check")
@@ -135,7 +166,7 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
}()
}
func pingWithRetry(tnet *netstack.Net, dst string) error {
func pingWithRetry(dev *device.Device, dst string) error {
const (
maxAttempts = 5
retryDelay = 2 * time.Second
@@ -145,7 +176,7 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
for attempt := 1; attempt <= maxAttempts; attempt++ {
logger.Info("Ping attempt %d of %d", attempt, maxAttempts)
if err := ping(tnet, dst); err != nil {
if err := ping(dev, dst); err != nil {
lastErr = err
logger.Warn("Ping attempt %d failed: %v", attempt, err)
@@ -161,7 +192,6 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
return nil
}
// This shouldn't be reached due to the return in the loop, but added for completeness
return fmt.Errorf("unexpected error: all ping attempts failed")
}
@@ -335,29 +365,13 @@ func main() {
logger.Fatal("Failed to create client: %v", err)
}
// Create WireGuard service
wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
if err != nil {
logger.Fatal("Failed to create WireGuard service: %v", err)
}
defer wgService.Close()
// Create TUN device and network stack
var tun tun.Device
var tnet *netstack.Net
var dev *device.Device
var pm *proxy.ProxyManager
var connected bool
var wgData WgData
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
client.RegisterHandler("client/terminate", func(msg websocket.WSMessage) {
logger.Info("Received terminate message")
if pm != nil {
pm.Stop()
}
if dev != nil {
dev.Close()
}
client.Close()
})
@@ -365,13 +379,12 @@ func main() {
defer close(pingStopChan)
// Register handlers for different message types
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
client.RegisterHandler("client/wg/connect", func(msg websocket.WSMessage) {
logger.Info("Received registration message")
if connected {
logger.Info("Already connected! But I will send a ping anyway...")
// ping(tnet, wgData.ServerIP)
err = pingWithRetry(tnet, wgData.ServerIP)
err := pingWithRetry(dev, wgData.ServerIP)
if err != nil {
// Handle complete failure after all retries
logger.Warn("Failed to ping %s: %v", wgData.ServerIP, err)
@@ -391,17 +404,39 @@ func main() {
return
}
logger.Info("Received: %+v", msg)
tun, tnet, err = netstack.CreateNetTUN(
[]netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
[]netip.Addr{netip.MustParseAddr(dns)},
mtuInt)
if err != nil {
logger.Error("Failed to create TUN device: %v", err)
}
// logger.Info("Received: %+v", msg)
// tun, tnet, err = netstack.CreateNetTUN(
// []netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
// []netip.Addr{netip.MustParseAddr(dns)},
// mtuInt)
// if err != nil {
// logger.Error("Failed to create TUN device: %v", err)
// }
tdev, err := func() (tun.Device, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
if tunFdStr == "" {
return tun.CreateTUN(interfaceName, mtuInt)
}
// construct tun device from supplied fd
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
if err != nil {
return nil, err
}
err = unix.SetNonblock(int(fd), true)
if err != nil {
return nil, err
}
file := os.NewFile(uintptr(fd), "")
return tun.CreateTUNFromFile(file, mtuInt)
}()
// Create WireGuard device
dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(
dev = device.NewDevice(tdev, conn.NewDefaultBind(), device.NewLogger(
mapToWireGuardLogLevel(loggerLevel),
"wireguard: ",
))
@@ -433,7 +468,7 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
logger.Info("WireGuard device created. Lets ping the server now...")
// Ping to bring the tunnel up on the server side quickly
// ping(tnet, wgData.ServerIP)
err = pingWithRetry(tnet, wgData.ServerIP)
err = pingWithRetry(dev, wgData.ServerIP)
if err != nil {
// Handle complete failure after all retries
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
@@ -441,114 +476,16 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
if !connected {
logger.Info("Starting ping check")
startPingCheck(tnet, wgData.ServerIP, pingStopChan)
startPingCheck(dev, wgData.ServerIP, pingStopChan)
}
// Create proxy manager
pm = proxy.NewProxyManager(tnet)
connected = true
// add the targets if there are any
if len(wgData.Targets.TCP) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
}
if len(wgData.Targets.UDP) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
}
err = pm.Start()
if err != nil {
logger.Error("Failed to start proxy manager: %v", err)
}
})
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
}
})
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
}
})
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData)
}
})
client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData)
}
})
client.OnConnect(func() error {
publicKey := privateKey.PublicKey()
logger.Debug("Public key: %s", publicKey)
err := client.SendMessage("newt/wg/register", map[string]interface{}{
err := client.SendMessage("client/wg/register", map[string]interface{}{
"publicKey": fmt.Sprintf("%s", publicKey),
})
if err != nil {
@@ -574,62 +511,3 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
// Cleanup
dev.Close()
}
func parseTargetData(data interface{}) (TargetData, error) {
var targetData TargetData
jsonData, err := json.Marshal(data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return targetData, err
}
if err := json.Unmarshal(jsonData, &targetData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return targetData, err
}
return targetData, nil
}
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
if len(parts) != 3 {
logger.Info("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil {
logger.Info("Invalid port: %s", parts[0])
continue
}
if action == "add" {
target := parts[1] + ":" + parts[2]
// Only remove the specific target if it exists
err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil {
// Ignore "target not found" errors as this is expected for new targets
if !strings.Contains(err.Error(), "target not found") {
logger.Error("Failed to remove existing target: %v", err)
}
}
// Add the new target
pm.AddTarget(proto, tunnelIP, port, target)
} else if action == "remove" {
logger.Info("Removing target with port %d", port)
err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil {
logger.Error("Failed to remove target: %v", err)
return err
}
}
}
return nil
}