mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Replace netstack and remove proxy
This commit is contained in:
322
main.go
322
main.go
@@ -1,7 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
@@ -9,7 +8,6 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
@@ -18,16 +16,15 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/proxy"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"github.com/fosrl/newt/wg"
|
||||
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
@@ -62,59 +59,93 @@ func fixKey(key string) string {
|
||||
return hex.EncodeToString(decoded)
|
||||
}
|
||||
|
||||
func ping(tnet *netstack.Net, dst string) error {
|
||||
logger.Info("Pinging %s", dst)
|
||||
socket, err := tnet.Dial("ping4", dst)
|
||||
const (
|
||||
ENV_WG_TUN_FD = "WG_TUN_FD"
|
||||
ENV_WG_UAPI_FD = "WG_UAPI_FD"
|
||||
ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
|
||||
)
|
||||
|
||||
func ping(dev *device.Device, dst string) error {
|
||||
logger.Info("Pinging %s over WireGuard tunnel", dst)
|
||||
|
||||
// Create a raw socket for ICMP
|
||||
conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create ICMP socket: %w", err)
|
||||
}
|
||||
defer socket.Close()
|
||||
defer conn.Close()
|
||||
|
||||
requestPing := icmp.Echo{
|
||||
Seq: rand.Intn(1 << 16),
|
||||
Data: []byte("gopher burrow"),
|
||||
// Parse destination IP
|
||||
dstIP := net.ParseIP(dst)
|
||||
if dstIP == nil {
|
||||
return fmt.Errorf("invalid destination IP: %s", dst)
|
||||
}
|
||||
|
||||
icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
|
||||
// Create ICMP message
|
||||
requestPing := icmp.Echo{
|
||||
ID: os.Getpid() & 0xffff,
|
||||
Seq: rand.Intn(1 << 16),
|
||||
Data: []byte("wireguard ping"),
|
||||
}
|
||||
|
||||
msg := icmp.Message{
|
||||
Type: ipv4.ICMPTypeEcho,
|
||||
Code: 0,
|
||||
Body: &requestPing,
|
||||
}
|
||||
|
||||
// Marshal the message
|
||||
icmpBytes, err := msg.Marshal(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal ICMP message: %w", err)
|
||||
}
|
||||
|
||||
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
|
||||
// Set read deadline
|
||||
if err := conn.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
|
||||
return fmt.Errorf("failed to set read deadline: %w", err)
|
||||
}
|
||||
|
||||
// Send the ping
|
||||
start := time.Now()
|
||||
_, err = socket.Write(icmpBytes)
|
||||
_, err = conn.WriteTo(icmpBytes, &net.IPAddr{IP: dstIP})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write ICMP packet: %w", err)
|
||||
}
|
||||
|
||||
n, err := socket.Read(icmpBytes[:])
|
||||
// Wait for reply
|
||||
reply := make([]byte, 1500)
|
||||
n, peer, err := conn.ReadFrom(reply)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read ICMP packet: %w", err)
|
||||
}
|
||||
|
||||
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
|
||||
// Parse reply
|
||||
replyMsg, err := icmp.ParseMessage(1, reply[:n])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse ICMP packet: %w", err)
|
||||
return fmt.Errorf("failed to parse ICMP reply: %w", err)
|
||||
}
|
||||
|
||||
replyPing, ok := replyPacket.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body)
|
||||
// Verify reply
|
||||
switch replyMsg.Type {
|
||||
case ipv4.ICMPTypeEchoReply:
|
||||
replyEcho, ok := replyMsg.Body.(*icmp.Echo)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyMsg.Body)
|
||||
}
|
||||
if replyEcho.ID != requestPing.ID || replyEcho.Seq != requestPing.Seq {
|
||||
return fmt.Errorf("invalid echo reply: got id=%d seq=%d, want id=%d seq=%d",
|
||||
replyEcho.ID, replyEcho.Seq, requestPing.ID, requestPing.Seq)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unexpected ICMP message type: %+v", replyMsg)
|
||||
}
|
||||
|
||||
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
|
||||
return fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q",
|
||||
replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data)
|
||||
}
|
||||
|
||||
logger.Info("Ping latency: %v", time.Since(start))
|
||||
duration := time.Since(start)
|
||||
logger.Info("Ping reply from %v: time=%v", peer, duration)
|
||||
return nil
|
||||
}
|
||||
|
||||
func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) {
|
||||
func startPingCheck(dev *device.Device, serverIP string, stopChan chan struct{}) {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -122,10 +153,10 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err := ping(tnet, serverIP)
|
||||
err := ping(dev, serverIP)
|
||||
if err != nil {
|
||||
logger.Warn("Periodic ping failed: %v", err)
|
||||
logger.Warn("HINT: Do you have UDP port 51280 (or the port in config.yml) open on your Pangolin server?")
|
||||
logger.Warn("HINT: Check if the WireGuard tunnel is up and the server is reachable")
|
||||
}
|
||||
case <-stopChan:
|
||||
logger.Info("Stopping ping check")
|
||||
@@ -135,7 +166,7 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
|
||||
}()
|
||||
}
|
||||
|
||||
func pingWithRetry(tnet *netstack.Net, dst string) error {
|
||||
func pingWithRetry(dev *device.Device, dst string) error {
|
||||
const (
|
||||
maxAttempts = 5
|
||||
retryDelay = 2 * time.Second
|
||||
@@ -145,7 +176,7 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
logger.Info("Ping attempt %d of %d", attempt, maxAttempts)
|
||||
|
||||
if err := ping(tnet, dst); err != nil {
|
||||
if err := ping(dev, dst); err != nil {
|
||||
lastErr = err
|
||||
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
||||
|
||||
@@ -161,7 +192,6 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// This shouldn't be reached due to the return in the loop, but added for completeness
|
||||
return fmt.Errorf("unexpected error: all ping attempts failed")
|
||||
}
|
||||
|
||||
@@ -335,29 +365,13 @@ func main() {
|
||||
logger.Fatal("Failed to create client: %v", err)
|
||||
}
|
||||
|
||||
// Create WireGuard service
|
||||
wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to create WireGuard service: %v", err)
|
||||
}
|
||||
defer wgService.Close()
|
||||
|
||||
// Create TUN device and network stack
|
||||
var tun tun.Device
|
||||
var tnet *netstack.Net
|
||||
var dev *device.Device
|
||||
var pm *proxy.ProxyManager
|
||||
var connected bool
|
||||
var wgData WgData
|
||||
|
||||
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
|
||||
client.RegisterHandler("client/terminate", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received terminate message")
|
||||
if pm != nil {
|
||||
pm.Stop()
|
||||
}
|
||||
if dev != nil {
|
||||
dev.Close()
|
||||
}
|
||||
client.Close()
|
||||
})
|
||||
|
||||
@@ -365,13 +379,12 @@ func main() {
|
||||
defer close(pingStopChan)
|
||||
|
||||
// Register handlers for different message types
|
||||
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
|
||||
client.RegisterHandler("client/wg/connect", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received registration message")
|
||||
|
||||
if connected {
|
||||
logger.Info("Already connected! But I will send a ping anyway...")
|
||||
// ping(tnet, wgData.ServerIP)
|
||||
err = pingWithRetry(tnet, wgData.ServerIP)
|
||||
err := pingWithRetry(dev, wgData.ServerIP)
|
||||
if err != nil {
|
||||
// Handle complete failure after all retries
|
||||
logger.Warn("Failed to ping %s: %v", wgData.ServerIP, err)
|
||||
@@ -391,17 +404,39 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Received: %+v", msg)
|
||||
tun, tnet, err = netstack.CreateNetTUN(
|
||||
[]netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
|
||||
[]netip.Addr{netip.MustParseAddr(dns)},
|
||||
mtuInt)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create TUN device: %v", err)
|
||||
}
|
||||
// logger.Info("Received: %+v", msg)
|
||||
// tun, tnet, err = netstack.CreateNetTUN(
|
||||
// []netip.Addr{netip.MustParseAddr(wgData.TunnelIP)},
|
||||
// []netip.Addr{netip.MustParseAddr(dns)},
|
||||
// mtuInt)
|
||||
// if err != nil {
|
||||
// logger.Error("Failed to create TUN device: %v", err)
|
||||
// }
|
||||
|
||||
tdev, err := func() (tun.Device, error) {
|
||||
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
||||
if tunFdStr == "" {
|
||||
return tun.CreateTUN(interfaceName, mtuInt)
|
||||
}
|
||||
|
||||
// construct tun device from supplied fd
|
||||
|
||||
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(int(fd), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "")
|
||||
return tun.CreateTUNFromFile(file, mtuInt)
|
||||
}()
|
||||
|
||||
// Create WireGuard device
|
||||
dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(
|
||||
dev = device.NewDevice(tdev, conn.NewDefaultBind(), device.NewLogger(
|
||||
mapToWireGuardLogLevel(loggerLevel),
|
||||
"wireguard: ",
|
||||
))
|
||||
@@ -433,7 +468,7 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
logger.Info("WireGuard device created. Lets ping the server now...")
|
||||
// Ping to bring the tunnel up on the server side quickly
|
||||
// ping(tnet, wgData.ServerIP)
|
||||
err = pingWithRetry(tnet, wgData.ServerIP)
|
||||
err = pingWithRetry(dev, wgData.ServerIP)
|
||||
if err != nil {
|
||||
// Handle complete failure after all retries
|
||||
logger.Error("Failed to ping %s: %v", wgData.ServerIP, err)
|
||||
@@ -441,114 +476,16 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
|
||||
if !connected {
|
||||
logger.Info("Starting ping check")
|
||||
startPingCheck(tnet, wgData.ServerIP, pingStopChan)
|
||||
startPingCheck(dev, wgData.ServerIP, pingStopChan)
|
||||
}
|
||||
|
||||
// Create proxy manager
|
||||
pm = proxy.NewProxyManager(tnet)
|
||||
|
||||
connected = true
|
||||
|
||||
// add the targets if there are any
|
||||
if len(wgData.Targets.TCP) > 0 {
|
||||
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
|
||||
}
|
||||
|
||||
if len(wgData.Targets.UDP) > 0 {
|
||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
|
||||
}
|
||||
|
||||
err = pm.Start()
|
||||
if err != nil {
|
||||
logger.Error("Failed to start proxy manager: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received: %+v", msg)
|
||||
|
||||
// if there is no wgData or pm, we can't add targets
|
||||
if wgData.TunnelIP == "" || pm == nil {
|
||||
logger.Info("No tunnel IP or proxy manager available")
|
||||
return
|
||||
}
|
||||
|
||||
targetData, err := parseTargetData(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error parsing target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(targetData.Targets) > 0 {
|
||||
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
|
||||
}
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received: %+v", msg)
|
||||
|
||||
// if there is no wgData or pm, we can't add targets
|
||||
if wgData.TunnelIP == "" || pm == nil {
|
||||
logger.Info("No tunnel IP or proxy manager available")
|
||||
return
|
||||
}
|
||||
|
||||
targetData, err := parseTargetData(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error parsing target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(targetData.Targets) > 0 {
|
||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
|
||||
}
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received: %+v", msg)
|
||||
|
||||
// if there is no wgData or pm, we can't add targets
|
||||
if wgData.TunnelIP == "" || pm == nil {
|
||||
logger.Info("No tunnel IP or proxy manager available")
|
||||
return
|
||||
}
|
||||
|
||||
targetData, err := parseTargetData(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error parsing target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(targetData.Targets) > 0 {
|
||||
updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData)
|
||||
}
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received: %+v", msg)
|
||||
|
||||
// if there is no wgData or pm, we can't add targets
|
||||
if wgData.TunnelIP == "" || pm == nil {
|
||||
logger.Info("No tunnel IP or proxy manager available")
|
||||
return
|
||||
}
|
||||
|
||||
targetData, err := parseTargetData(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error parsing target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(targetData.Targets) > 0 {
|
||||
updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData)
|
||||
}
|
||||
})
|
||||
|
||||
client.OnConnect(func() error {
|
||||
publicKey := privateKey.PublicKey()
|
||||
logger.Debug("Public key: %s", publicKey)
|
||||
|
||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||
err := client.SendMessage("client/wg/register", map[string]interface{}{
|
||||
"publicKey": fmt.Sprintf("%s", publicKey),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -574,62 +511,3 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(
|
||||
// Cleanup
|
||||
dev.Close()
|
||||
}
|
||||
|
||||
func parseTargetData(data interface{}) (TargetData, error) {
|
||||
var targetData TargetData
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
return targetData, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &targetData); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
return targetData, err
|
||||
}
|
||||
return targetData, nil
|
||||
}
|
||||
|
||||
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
|
||||
for _, t := range targetData.Targets {
|
||||
// Split the first number off of the target with : separator and use as the port
|
||||
parts := strings.Split(t, ":")
|
||||
if len(parts) != 3 {
|
||||
logger.Info("Invalid target format: %s", t)
|
||||
continue
|
||||
}
|
||||
|
||||
// Get the port as an int
|
||||
port := 0
|
||||
_, err := fmt.Sscanf(parts[0], "%d", &port)
|
||||
if err != nil {
|
||||
logger.Info("Invalid port: %s", parts[0])
|
||||
continue
|
||||
}
|
||||
|
||||
if action == "add" {
|
||||
target := parts[1] + ":" + parts[2]
|
||||
// Only remove the specific target if it exists
|
||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
||||
if err != nil {
|
||||
// Ignore "target not found" errors as this is expected for new targets
|
||||
if !strings.Contains(err.Error(), "target not found") {
|
||||
logger.Error("Failed to remove existing target: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Add the new target
|
||||
pm.AddTarget(proto, tunnelIP, port, target)
|
||||
|
||||
} else if action == "remove" {
|
||||
logger.Info("Removing target with port %d", port)
|
||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
||||
if err != nil {
|
||||
logger.Error("Failed to remove target: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user