diff --git a/main.go b/main.go index 942acef..63caa23 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "encoding/base64" "encoding/hex" "encoding/json" @@ -9,7 +8,6 @@ import ( "fmt" "math/rand" "net" - "net/netip" "os" "os/signal" "strconv" @@ -18,16 +16,15 @@ import ( "time" "github.com/fosrl/newt/logger" - "github.com/fosrl/newt/proxy" "github.com/fosrl/newt/websocket" - "github.com/fosrl/newt/wg" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" + "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/tun/netstack" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -62,59 +59,93 @@ func fixKey(key string) string { return hex.EncodeToString(decoded) } -func ping(tnet *netstack.Net, dst string) error { - logger.Info("Pinging %s", dst) - socket, err := tnet.Dial("ping4", dst) +const ( + ENV_WG_TUN_FD = "WG_TUN_FD" + ENV_WG_UAPI_FD = "WG_UAPI_FD" + ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND" +) + +func ping(dev *device.Device, dst string) error { + logger.Info("Pinging %s over WireGuard tunnel", dst) + + // Create a raw socket for ICMP + conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0") if err != nil { return fmt.Errorf("failed to create ICMP socket: %w", err) } - defer socket.Close() + defer conn.Close() - requestPing := icmp.Echo{ - Seq: rand.Intn(1 << 16), - Data: []byte("gopher burrow"), + // Parse destination IP + dstIP := net.ParseIP(dst) + if dstIP == nil { + return fmt.Errorf("invalid destination IP: %s", dst) } - icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) + // Create ICMP message + requestPing := icmp.Echo{ + ID: os.Getpid() & 0xffff, + Seq: rand.Intn(1 << 16), + Data: []byte("wireguard ping"), + } + + msg := icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &requestPing, + } + + // Marshal the message + icmpBytes, err := msg.Marshal(nil) if err != nil { return fmt.Errorf("failed to marshal ICMP message: %w", err) } - if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil { + // Set read deadline + if err := conn.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil { return fmt.Errorf("failed to set read deadline: %w", err) } + // Send the ping start := time.Now() - _, err = socket.Write(icmpBytes) + _, err = conn.WriteTo(icmpBytes, &net.IPAddr{IP: dstIP}) if err != nil { return fmt.Errorf("failed to write ICMP packet: %w", err) } - n, err := socket.Read(icmpBytes[:]) + // Wait for reply + reply := make([]byte, 1500) + n, peer, err := conn.ReadFrom(reply) if err != nil { return fmt.Errorf("failed to read ICMP packet: %w", err) } - replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) + // Parse reply + replyMsg, err := icmp.ParseMessage(1, reply[:n]) if err != nil { - return fmt.Errorf("failed to parse ICMP packet: %w", err) + return fmt.Errorf("failed to parse ICMP reply: %w", err) } - replyPing, ok := replyPacket.Body.(*icmp.Echo) - if !ok { - return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body) + // Verify reply + switch replyMsg.Type { + case ipv4.ICMPTypeEchoReply: + replyEcho, ok := replyMsg.Body.(*icmp.Echo) + if !ok { + return fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyMsg.Body) + } + if replyEcho.ID != requestPing.ID || replyEcho.Seq != requestPing.Seq { + return fmt.Errorf("invalid echo reply: got id=%d seq=%d, want id=%d seq=%d", + replyEcho.ID, replyEcho.Seq, requestPing.ID, requestPing.Seq) + } + default: + return fmt.Errorf("unexpected ICMP message type: %+v", replyMsg) } - if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { - return fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q", - replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data) - } - - logger.Info("Ping latency: %v", time.Since(start)) + duration := time.Since(start) + logger.Info("Ping reply from %v: time=%v", peer, duration) return nil } -func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) { +func startPingCheck(dev *device.Device, serverIP string, stopChan chan struct{}) { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() @@ -122,10 +153,10 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) for { select { case <-ticker.C: - err := ping(tnet, serverIP) + err := ping(dev, serverIP) if err != nil { logger.Warn("Periodic ping failed: %v", err) - logger.Warn("HINT: Do you have UDP port 51280 (or the port in config.yml) open on your Pangolin server?") + logger.Warn("HINT: Check if the WireGuard tunnel is up and the server is reachable") } case <-stopChan: logger.Info("Stopping ping check") @@ -135,7 +166,7 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) }() } -func pingWithRetry(tnet *netstack.Net, dst string) error { +func pingWithRetry(dev *device.Device, dst string) error { const ( maxAttempts = 5 retryDelay = 2 * time.Second @@ -145,7 +176,7 @@ func pingWithRetry(tnet *netstack.Net, dst string) error { for attempt := 1; attempt <= maxAttempts; attempt++ { logger.Info("Ping attempt %d of %d", attempt, maxAttempts) - if err := ping(tnet, dst); err != nil { + if err := ping(dev, dst); err != nil { lastErr = err logger.Warn("Ping attempt %d failed: %v", attempt, err) @@ -161,7 +192,6 @@ func pingWithRetry(tnet *netstack.Net, dst string) error { return nil } - // This shouldn't be reached due to the return in the loop, but added for completeness return fmt.Errorf("unexpected error: all ping attempts failed") } @@ -335,29 +365,13 @@ func main() { logger.Fatal("Failed to create client: %v", err) } - // Create WireGuard service - wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client) - if err != nil { - logger.Fatal("Failed to create WireGuard service: %v", err) - } - defer wgService.Close() - // Create TUN device and network stack - var tun tun.Device - var tnet *netstack.Net var dev *device.Device - var pm *proxy.ProxyManager var connected bool var wgData WgData - client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) { + client.RegisterHandler("client/terminate", func(msg websocket.WSMessage) { logger.Info("Received terminate message") - if pm != nil { - pm.Stop() - } - if dev != nil { - dev.Close() - } client.Close() }) @@ -365,13 +379,12 @@ func main() { defer close(pingStopChan) // Register handlers for different message types - client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) { + client.RegisterHandler("client/wg/connect", func(msg websocket.WSMessage) { logger.Info("Received registration message") if connected { logger.Info("Already connected! But I will send a ping anyway...") - // ping(tnet, wgData.ServerIP) - err = pingWithRetry(tnet, wgData.ServerIP) + err := pingWithRetry(dev, wgData.ServerIP) if err != nil { // Handle complete failure after all retries logger.Warn("Failed to ping %s: %v", wgData.ServerIP, err) @@ -391,17 +404,39 @@ func main() { return } - logger.Info("Received: %+v", msg) - tun, tnet, err = netstack.CreateNetTUN( - []netip.Addr{netip.MustParseAddr(wgData.TunnelIP)}, - []netip.Addr{netip.MustParseAddr(dns)}, - mtuInt) - if err != nil { - logger.Error("Failed to create TUN device: %v", err) - } + // logger.Info("Received: %+v", msg) + // tun, tnet, err = netstack.CreateNetTUN( + // []netip.Addr{netip.MustParseAddr(wgData.TunnelIP)}, + // []netip.Addr{netip.MustParseAddr(dns)}, + // mtuInt) + // if err != nil { + // logger.Error("Failed to create TUN device: %v", err) + // } + + tdev, err := func() (tun.Device, error) { + tunFdStr := os.Getenv(ENV_WG_TUN_FD) + if tunFdStr == "" { + return tun.CreateTUN(interfaceName, mtuInt) + } + + // construct tun device from supplied fd + + fd, err := strconv.ParseUint(tunFdStr, 10, 32) + if err != nil { + return nil, err + } + + err = unix.SetNonblock(int(fd), true) + if err != nil { + return nil, err + } + + file := os.NewFile(uintptr(fd), "") + return tun.CreateTUNFromFile(file, mtuInt) + }() // Create WireGuard device - dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger( + dev = device.NewDevice(tdev, conn.NewDefaultBind(), device.NewLogger( mapToWireGuardLogLevel(loggerLevel), "wireguard: ", )) @@ -433,7 +468,7 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( logger.Info("WireGuard device created. Lets ping the server now...") // Ping to bring the tunnel up on the server side quickly // ping(tnet, wgData.ServerIP) - err = pingWithRetry(tnet, wgData.ServerIP) + err = pingWithRetry(dev, wgData.ServerIP) if err != nil { // Handle complete failure after all retries logger.Error("Failed to ping %s: %v", wgData.ServerIP, err) @@ -441,114 +476,16 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( if !connected { logger.Info("Starting ping check") - startPingCheck(tnet, wgData.ServerIP, pingStopChan) + startPingCheck(dev, wgData.ServerIP, pingStopChan) } - - // Create proxy manager - pm = proxy.NewProxyManager(tnet) - connected = true - - // add the targets if there are any - if len(wgData.Targets.TCP) > 0 { - updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP}) - } - - if len(wgData.Targets.UDP) > 0 { - updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP}) - } - - err = pm.Start() - if err != nil { - logger.Error("Failed to start proxy manager: %v", err) - } - }) - - client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) { - logger.Info("Received: %+v", msg) - - // if there is no wgData or pm, we can't add targets - if wgData.TunnelIP == "" || pm == nil { - logger.Info("No tunnel IP or proxy manager available") - return - } - - targetData, err := parseTargetData(msg.Data) - if err != nil { - logger.Info("Error parsing target data: %v", err) - return - } - - if len(targetData.Targets) > 0 { - updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData) - } - }) - - client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) { - logger.Info("Received: %+v", msg) - - // if there is no wgData or pm, we can't add targets - if wgData.TunnelIP == "" || pm == nil { - logger.Info("No tunnel IP or proxy manager available") - return - } - - targetData, err := parseTargetData(msg.Data) - if err != nil { - logger.Info("Error parsing target data: %v", err) - return - } - - if len(targetData.Targets) > 0 { - updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData) - } - }) - - client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) { - logger.Info("Received: %+v", msg) - - // if there is no wgData or pm, we can't add targets - if wgData.TunnelIP == "" || pm == nil { - logger.Info("No tunnel IP or proxy manager available") - return - } - - targetData, err := parseTargetData(msg.Data) - if err != nil { - logger.Info("Error parsing target data: %v", err) - return - } - - if len(targetData.Targets) > 0 { - updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData) - } - }) - - client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) { - logger.Info("Received: %+v", msg) - - // if there is no wgData or pm, we can't add targets - if wgData.TunnelIP == "" || pm == nil { - logger.Info("No tunnel IP or proxy manager available") - return - } - - targetData, err := parseTargetData(msg.Data) - if err != nil { - logger.Info("Error parsing target data: %v", err) - return - } - - if len(targetData.Targets) > 0 { - updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData) - } }) client.OnConnect(func() error { publicKey := privateKey.PublicKey() logger.Debug("Public key: %s", publicKey) - err := client.SendMessage("newt/wg/register", map[string]interface{}{ + err := client.SendMessage("client/wg/register", map[string]interface{}{ "publicKey": fmt.Sprintf("%s", publicKey), }) if err != nil { @@ -574,62 +511,3 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( // Cleanup dev.Close() } - -func parseTargetData(data interface{}) (TargetData, error) { - var targetData TargetData - jsonData, err := json.Marshal(data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return targetData, err - } - - if err := json.Unmarshal(jsonData, &targetData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return targetData, err - } - return targetData, nil -} - -func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { - for _, t := range targetData.Targets { - // Split the first number off of the target with : separator and use as the port - parts := strings.Split(t, ":") - if len(parts) != 3 { - logger.Info("Invalid target format: %s", t) - continue - } - - // Get the port as an int - port := 0 - _, err := fmt.Sscanf(parts[0], "%d", &port) - if err != nil { - logger.Info("Invalid port: %s", parts[0]) - continue - } - - if action == "add" { - target := parts[1] + ":" + parts[2] - // Only remove the specific target if it exists - err := pm.RemoveTarget(proto, tunnelIP, port) - if err != nil { - // Ignore "target not found" errors as this is expected for new targets - if !strings.Contains(err.Error(), "target not found") { - logger.Error("Failed to remove existing target: %v", err) - } - } - - // Add the new target - pm.AddTarget(proto, tunnelIP, port, target) - - } else if action == "remove" { - logger.Info("Removing target with port %d", port) - err := pm.RemoveTarget(proto, tunnelIP, port) - if err != nil { - logger.Error("Failed to remove target: %v", err) - return err - } - } - } - - return nil -} diff --git a/proxy/manager.go b/proxy/manager.go deleted file mode 100644 index 0792acb..0000000 --- a/proxy/manager.go +++ /dev/null @@ -1,352 +0,0 @@ -package proxy - -import ( - "fmt" - "io" - "net" - "strings" - "sync" - "time" - - "github.com/fosrl/newt/logger" - "golang.zx2c4.com/wireguard/tun/netstack" - "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" -) - -// Target represents a proxy target with its address and port -type Target struct { - Address string - Port int -} - -// ProxyManager handles the creation and management of proxy connections -type ProxyManager struct { - tnet *netstack.Net - tcpTargets map[string]map[int]string // map[listenIP]map[port]targetAddress - udpTargets map[string]map[int]string - listeners []*gonet.TCPListener - udpConns []*gonet.UDPConn - running bool - mutex sync.RWMutex -} - -// NewProxyManager creates a new proxy manager instance -func NewProxyManager(tnet *netstack.Net) *ProxyManager { - return &ProxyManager{ - tnet: tnet, - tcpTargets: make(map[string]map[int]string), - udpTargets: make(map[string]map[int]string), - listeners: make([]*gonet.TCPListener, 0), - udpConns: make([]*gonet.UDPConn, 0), - } -} - -// AddTarget adds as new target for proxying -func (pm *ProxyManager) AddTarget(proto, listenIP string, port int, targetAddr string) error { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - switch proto { - case "tcp": - if pm.tcpTargets[listenIP] == nil { - pm.tcpTargets[listenIP] = make(map[int]string) - } - pm.tcpTargets[listenIP][port] = targetAddr - case "udp": - if pm.udpTargets[listenIP] == nil { - pm.udpTargets[listenIP] = make(map[int]string) - } - pm.udpTargets[listenIP][port] = targetAddr - default: - return fmt.Errorf("unsupported protocol: %s", proto) - } - - if pm.running { - return pm.startTarget(proto, listenIP, port, targetAddr) - } else { - logger.Debug("Not adding target because not running") - } - return nil -} - -func (pm *ProxyManager) RemoveTarget(proto, listenIP string, port int) error { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - switch proto { - case "tcp": - if targets, ok := pm.tcpTargets[listenIP]; ok { - delete(targets, port) - // Remove and close the corresponding TCP listener - for i, listener := range pm.listeners { - if addr, ok := listener.Addr().(*net.TCPAddr); ok && addr.Port == port { - listener.Close() - time.Sleep(50 * time.Millisecond) - // Remove from slice - pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...) - break - } - } - } else { - return fmt.Errorf("target not found: %s:%d", listenIP, port) - } - case "udp": - if targets, ok := pm.udpTargets[listenIP]; ok { - delete(targets, port) - // Remove and close the corresponding UDP connection - for i, conn := range pm.udpConns { - if addr, ok := conn.LocalAddr().(*net.UDPAddr); ok && addr.Port == port { - conn.Close() - time.Sleep(50 * time.Millisecond) - // Remove from slice - pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...) - break - } - } - } else { - return fmt.Errorf("target not found: %s:%d", listenIP, port) - } - default: - return fmt.Errorf("unsupported protocol: %s", proto) - } - return nil -} - -// Start begins listening for all configured proxy targets -func (pm *ProxyManager) Start() error { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - if pm.running { - return nil - } - - // Start TCP targets - for listenIP, targets := range pm.tcpTargets { - for port, targetAddr := range targets { - if err := pm.startTarget("tcp", listenIP, port, targetAddr); err != nil { - return fmt.Errorf("failed to start TCP target: %v", err) - } - } - } - - // Start UDP targets - for listenIP, targets := range pm.udpTargets { - for port, targetAddr := range targets { - if err := pm.startTarget("udp", listenIP, port, targetAddr); err != nil { - return fmt.Errorf("failed to start UDP target: %v", err) - } - } - } - - pm.running = true - return nil -} - -func (pm *ProxyManager) Stop() error { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - if !pm.running { - return nil - } - - // Set running to false first to signal handlers to stop - pm.running = false - - // Close TCP listeners - for i := len(pm.listeners) - 1; i >= 0; i-- { - listener := pm.listeners[i] - if err := listener.Close(); err != nil { - logger.Error("Error closing TCP listener: %v", err) - } - // Remove from slice - pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...) - } - - // Close UDP connections - for i := len(pm.udpConns) - 1; i >= 0; i-- { - conn := pm.udpConns[i] - if err := conn.Close(); err != nil { - logger.Error("Error closing UDP connection: %v", err) - } - // Remove from slice - pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...) - } - - // Clear the target maps - for k := range pm.tcpTargets { - delete(pm.tcpTargets, k) - } - for k := range pm.udpTargets { - delete(pm.udpTargets, k) - } - - // Give active connections a chance to close gracefully - time.Sleep(100 * time.Millisecond) - - return nil -} - -func (pm *ProxyManager) startTarget(proto, listenIP string, port int, targetAddr string) error { - switch proto { - case "tcp": - listener, err := pm.tnet.ListenTCP(&net.TCPAddr{Port: port}) - if err != nil { - return fmt.Errorf("failed to create TCP listener: %v", err) - } - - pm.listeners = append(pm.listeners, listener) - go pm.handleTCPProxy(listener, targetAddr) - - case "udp": - addr := &net.UDPAddr{Port: port} - conn, err := pm.tnet.ListenUDP(addr) - if err != nil { - return fmt.Errorf("failed to create UDP listener: %v", err) - } - - pm.udpConns = append(pm.udpConns, conn) - go pm.handleUDPProxy(conn, targetAddr) - - default: - return fmt.Errorf("unsupported protocol: %s", proto) - } - - logger.Info("Started %s proxy from %s:%d to %s", proto, listenIP, port, targetAddr) - - return nil -} - -func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string) { - for { - conn, err := listener.Accept() - if err != nil { - // Check if we're shutting down or the listener was closed - if !pm.running { - return - } - - // Check for specific network errors that indicate the listener is closed - if ne, ok := err.(net.Error); ok && !ne.Temporary() { - logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr()) - return - } - - logger.Error("Error accepting TCP connection: %v", err) - // Don't hammer the CPU if we hit a temporary error - time.Sleep(100 * time.Millisecond) - continue - } - - go func() { - target, err := net.Dial("tcp", targetAddr) - if err != nil { - logger.Error("Error connecting to target: %v", err) - conn.Close() - return - } - - // Create a WaitGroup to ensure both copy operations complete - var wg sync.WaitGroup - wg.Add(2) - - go func() { - defer wg.Done() - io.Copy(target, conn) - target.Close() - }() - - go func() { - defer wg.Done() - io.Copy(conn, target) - conn.Close() - }() - - // Wait for both copies to complete - wg.Wait() - }() - } -} - -func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) { - buffer := make([]byte, 65507) // Max UDP packet size - clientConns := make(map[string]*net.UDPConn) - var clientsMutex sync.RWMutex - - for { - n, remoteAddr, err := conn.ReadFrom(buffer) - if err != nil { - if !pm.running { - return - } - - // Check for connection closed conditions - if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") { - logger.Info("UDP connection closed, stopping proxy handler") - - // Clean up existing client connections - clientsMutex.Lock() - for _, targetConn := range clientConns { - targetConn.Close() - } - clientConns = nil - clientsMutex.Unlock() - - return - } - - logger.Error("Error reading UDP packet: %v", err) - continue - } - - clientKey := remoteAddr.String() - clientsMutex.RLock() - targetConn, exists := clientConns[clientKey] - clientsMutex.RUnlock() - - if !exists { - targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr) - if err != nil { - logger.Error("Error resolving target address: %v", err) - continue - } - - targetConn, err = net.DialUDP("udp", nil, targetUDPAddr) - if err != nil { - logger.Error("Error connecting to target: %v", err) - continue - } - - clientsMutex.Lock() - clientConns[clientKey] = targetConn - clientsMutex.Unlock() - - go func() { - buffer := make([]byte, 65507) - for { - n, _, err := targetConn.ReadFromUDP(buffer) - if err != nil { - logger.Error("Error reading from target: %v", err) - return - } - - _, err = conn.WriteTo(buffer[:n], remoteAddr) - if err != nil { - logger.Error("Error writing to client: %v", err) - return - } - } - }() - } - - _, err = targetConn.Write(buffer[:n]) - if err != nil { - logger.Error("Error writing to target: %v", err) - targetConn.Close() - clientsMutex.Lock() - delete(clientConns, clientKey) - clientsMutex.Unlock() - } - } -} diff --git a/wg/wg.go b/wg/wg.go deleted file mode 100644 index 4699ed7..0000000 --- a/wg/wg.go +++ /dev/null @@ -1,606 +0,0 @@ -package wg - -import ( - "bytes" - "encoding/json" - "fmt" - "net" - "os" - "sync" - "time" - - "github.com/fosrl/newt/logger" - "github.com/fosrl/newt/websocket" - "github.com/vishvananda/netlink" - "golang.zx2c4.com/wireguard/wgctrl" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -var ( - interfaceName string - listenAddr string - mtuInt int - lastReadings = make(map[string]PeerReading) - mu sync.Mutex -) - -type WgConfig struct { - PrivateKey string `json:"privateKey"` - ListenPort int `json:"listenPort"` - IpAddress string `json:"ipAddress"` - Peers []Peer `json:"peers"` -} - -type Peer struct { - PublicKey string `json:"publicKey"` - AllowedIPs []string `json:"allowedIps"` -} - -type PeerBandwidth struct { - PublicKey string `json:"publicKey"` - BytesIn float64 `json:"bytesIn"` - BytesOut float64 `json:"bytesOut"` -} - -type PeerReading struct { - BytesReceived int64 - BytesTransmitted int64 - LastChecked time.Time -} - -var ( - wgClient *wgctrl.Client -) - -type WireGuardService struct { - interfaceName string - mtu int - client *websocket.Client - wgClient *wgctrl.Client - config WgConfig - key wgtypes.Key - reachableAt string - lastReadings map[string]PeerReading - mu sync.Mutex -} - -func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) { - wgClient, err := wgctrl.New() - if err != nil { - return nil, fmt.Errorf("failed to create WireGuard client: %v", err) - } - - key := wgtypes.Key{} - // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file - if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { - // generate a new private key - key, err = wgtypes.GeneratePrivateKey() - if err != nil { - logger.Fatal("Failed to generate private key: %v", err) - } - // save the key to the file - err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644) - if err != nil { - logger.Fatal("Failed to save private key: %v", err) - } - } else { - keyData, err := os.ReadFile(generateAndSaveKeyTo) - if err != nil { - logger.Fatal("Failed to read private key: %v", err) - } - key, err = wgtypes.ParseKey(string(keyData)) - if err != nil { - logger.Fatal("Failed to parse private key: %v", err) - } - } - - service := &WireGuardService{ - interfaceName: interfaceName, - mtu: mtu, - client: wsClient, - wgClient: wgClient, - key: key, - reachableAt: reachableAt, - lastReadings: make(map[string]PeerReading), - } - - // Register websocket handlers - wsClient.RegisterHandler("wg/config/receive", service.handleConfig) - wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer) - wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer) - - // Register connect handler to initiate configuration - wsClient.OnConnect(service.loadRemoteConfig) - - return service, nil -} - -func (s *WireGuardService) Close() { - s.client.Close() - wgClient.Close() -} - -func (s *WireGuardService) loadRemoteConfig() error { - body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "endpoint": "%s"}`, s.key.PublicKey().String(), s.reachableAt))) - - go s.periodicBandwidthCheck() - - err := s.client.SendMessage("wg/config/get", body) - if err != nil { - return fmt.Errorf("failed to send config request: %v", err) - } - - return nil -} - -func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { - var config WgConfig - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - } - - if err := json.Unmarshal(jsonData, &config); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - } - - s.config = config - - // Ensure the WireGuard interface and peers are configured - if err := s.ensureWireguardInterface(config); err != nil { - logger.Error("Failed to ensure WireGuard interface: %v", err) - } - - if err := s.ensureWireguardPeers(config.Peers); err != nil { - logger.Error("Failed to ensure WireGuard peers: %v", err) - } -} - -func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { - // Check if the WireGuard interface exists - _, err := netlink.LinkByName(interfaceName) - if err != nil { - if _, ok := err.(netlink.LinkNotFoundError); ok { - // Interface doesn't exist, so create it - err = s.createWireGuardInterface() - if err != nil { - logger.Fatal("Failed to create WireGuard interface: %v", err) - } - logger.Info("Created WireGuard interface %s\n", interfaceName) - } else { - logger.Fatal("Error checking for WireGuard interface: %v", err) - } - } else { - logger.Info("WireGuard interface %s already exists\n", interfaceName) - return nil - } - - // Assign IP address to the interface - err = s.assignIPAddress(wgconfig.IpAddress) - if err != nil { - logger.Fatal("Failed to assign IP address: %v", err) - } - logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, interfaceName) - - // Check if the interface already exists - _, err = wgClient.Device(interfaceName) - if err != nil { - return fmt.Errorf("interface %s does not exist", interfaceName) - } - - // Parse the private key - key, err := wgtypes.ParseKey(wgconfig.PrivateKey) - if err != nil { - return fmt.Errorf("failed to parse private key: %v", err) - } - - // Create a new WireGuard configuration - config := wgtypes.Config{ - PrivateKey: &key, - ListenPort: new(int), - } - *config.ListenPort = wgconfig.ListenPort - - // Create and configure the WireGuard interface - err = wgClient.ConfigureDevice(interfaceName, config) - if err != nil { - return fmt.Errorf("failed to configure WireGuard device: %v", err) - } - - // bring up the interface - link, err := netlink.LinkByName(interfaceName) - if err != nil { - return fmt.Errorf("failed to get interface: %v", err) - } - - if err := netlink.LinkSetMTU(link, mtuInt); err != nil { - return fmt.Errorf("failed to set MTU: %v", err) - } - - if err := netlink.LinkSetUp(link); err != nil { - return fmt.Errorf("failed to bring up interface: %v", err) - } - - // if err := s.ensureMSSClamping(); err != nil { - // logger.Warn("Failed to ensure MSS clamping: %v", err) - // } - - logger.Info("WireGuard interface %s created and configured", interfaceName) - - return nil -} - -func (s *WireGuardService) createWireGuardInterface() error { - wgLink := &netlink.GenericLink{ - LinkAttrs: netlink.LinkAttrs{Name: interfaceName}, - LinkType: "wireguard", - } - return netlink.LinkAdd(wgLink) -} - -func (s *WireGuardService) assignIPAddress(ipAddress string) error { - link, err := netlink.LinkByName(interfaceName) - if err != nil { - return fmt.Errorf("failed to get interface: %v", err) - } - - addr, err := netlink.ParseAddr(ipAddress) - if err != nil { - return fmt.Errorf("failed to parse IP address: %v", err) - } - - return netlink.AddrAdd(link, addr) -} - -func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { - // get the current peers - device, err := wgClient.Device(interfaceName) - if err != nil { - return fmt.Errorf("failed to get device: %v", err) - } - - // get the peer public keys - var currentPeers []string - for _, peer := range device.Peers { - currentPeers = append(currentPeers, peer.PublicKey.String()) - } - - // remove any peers that are not in the config - for _, peer := range currentPeers { - found := false - for _, configPeer := range peers { - if peer == configPeer.PublicKey { - found = true - break - } - } - if !found { - err := s.removePeer(peer) - if err != nil { - return fmt.Errorf("failed to remove peer: %v", err) - } - } - } - - // add any peers that are in the config but not in the current peers - for _, configPeer := range peers { - found := false - for _, peer := range currentPeers { - if configPeer.PublicKey == peer { - found = true - break - } - } - if !found { - err := s.addPeer(configPeer) - if err != nil { - return fmt.Errorf("failed to add peer: %v", err) - } - } - } - - return nil -} - -// func (s *WireGuardService) ensureMSSClamping() error { -// // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20)) -// mssValue := mtuInt - 40 - -// // Rules to be managed - just the chains, we'll construct the full command separately -// chains := []string{"INPUT", "OUTPUT", "FORWARD"} - -// // First, try to delete any existing rules -// for _, chain := range chains { -// deleteCmd := exec.Command("/usr/sbin/iptables", -// "-t", "mangle", -// "-D", chain, -// "-p", "tcp", -// "--tcp-flags", "SYN,RST", "SYN", -// "-j", "TCPMSS", -// "--set-mss", fmt.Sprintf("%d", mssValue)) - -// logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain) - -// // Try deletion multiple times to handle multiple existing rules -// for i := 0; i < 3; i++ { -// out, err := deleteCmd.CombinedOutput() -// if err != nil { -// // Convert exit status 1 to string for better logging -// if exitErr, ok := err.(*exec.ExitError); ok { -// logger.Debug("Deletion stopped for chain %s: %v (output: %s)", -// chain, exitErr.String(), string(out)) -// } -// break // No more rules to delete -// } -// logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1) -// } -// } - -// // Then add the new rules -// var errors []error -// for _, chain := range chains { -// addCmd := exec.Command("/usr/sbin/iptables", -// "-t", "mangle", -// "-A", chain, -// "-p", "tcp", -// "--tcp-flags", "SYN,RST", "SYN", -// "-j", "TCPMSS", -// "--set-mss", fmt.Sprintf("%d", mssValue)) - -// logger.Info("Adding MSS clamping rule for chain %s", chain) - -// if out, err := addCmd.CombinedOutput(); err != nil { -// errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)", -// chain, err, string(out)) -// logger.Error(errMsg) -// errors = append(errors, fmt.Errorf(errMsg)) -// continue -// } - -// // Verify the rule was added -// checkCmd := exec.Command("/usr/sbin/iptables", -// "-t", "mangle", -// "-C", chain, -// "-p", "tcp", -// "--tcp-flags", "SYN,RST", "SYN", -// "-j", "TCPMSS", -// "--set-mss", fmt.Sprintf("%d", mssValue)) - -// if out, err := checkCmd.CombinedOutput(); err != nil { -// errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)", -// chain, err, string(out)) -// logger.Error(errMsg) -// errors = append(errors, fmt.Errorf(errMsg)) -// continue -// } - -// logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain) -// } - -// // If we encountered any errors, return them combined -// if len(errors) > 0 { -// var errMsgs []string -// for _, err := range errors { -// errMsgs = append(errMsgs, err.Error()) -// } -// return fmt.Errorf("MSS clamping setup encountered errors:\n%s", -// strings.Join(errMsgs, "\n")) -// } - -// return nil -// } - -func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) { - var peer Peer - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - } - - if err := json.Unmarshal(jsonData, &peer); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - } - - err = s.addPeer(peer) - if err != nil { - return - } -} - -func (s *WireGuardService) addPeer(peer Peer) error { - pubKey, err := wgtypes.ParseKey(peer.PublicKey) - if err != nil { - return fmt.Errorf("failed to parse public key: %v", err) - } - - // parse allowed IPs into array of net.IPNet - var allowedIPs []net.IPNet - for _, ipStr := range peer.AllowedIPs { - _, ipNet, err := net.ParseCIDR(ipStr) - if err != nil { - return fmt.Errorf("failed to parse allowed IP: %v", err) - } - allowedIPs = append(allowedIPs, *ipNet) - } - - peerConfig := wgtypes.PeerConfig{ - PublicKey: pubKey, - AllowedIPs: allowedIPs, - } - - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peerConfig}, - } - - if err := wgClient.ConfigureDevice(interfaceName, config); err != nil { - return fmt.Errorf("failed to add peer: %v", err) - } - - logger.Info("Peer %s added successfully", peer.PublicKey) - - return nil -} - -func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) { - // parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" } - type RemoveRequest struct { - PublicKey string `json:"publicKey"` - } - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - } - - var request RemoveRequest - if err := json.Unmarshal(jsonData, &request); err != nil { - logger.Info("Error unmarshaling data: %v", err) - return - } - - if err := s.removePeer(request.PublicKey); err != nil { - logger.Info("Error removing peer: %v", err) - return - } -} - -func (s *WireGuardService) removePeer(publicKey string) error { - pubKey, err := wgtypes.ParseKey(publicKey) - if err != nil { - return fmt.Errorf("failed to parse public key: %v", err) - } - - peerConfig := wgtypes.PeerConfig{ - PublicKey: pubKey, - Remove: true, - } - - config := wgtypes.Config{ - Peers: []wgtypes.PeerConfig{peerConfig}, - } - - if err := wgClient.ConfigureDevice(interfaceName, config); err != nil { - return fmt.Errorf("failed to remove peer: %v", err) - } - - logger.Info("Peer %s removed successfully", publicKey) - - return nil -} - -func (s *WireGuardService) periodicBandwidthCheck() { - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - - for range ticker.C { - if err := s.reportPeerBandwidth(); err != nil { - logger.Info("Failed to report peer bandwidth: %v", err) - } - } -} - -func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { - device, err := wgClient.Device(interfaceName) - if err != nil { - return nil, fmt.Errorf("failed to get device: %v", err) - } - - peerBandwidths := []PeerBandwidth{} - now := time.Now() - - mu.Lock() - defer mu.Unlock() - - for _, peer := range device.Peers { - publicKey := peer.PublicKey.String() - currentReading := PeerReading{ - BytesReceived: peer.ReceiveBytes, - BytesTransmitted: peer.TransmitBytes, - LastChecked: now, - } - - var bytesInDiff, bytesOutDiff float64 - lastReading, exists := lastReadings[publicKey] - - if exists { - timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds() - if timeDiff > 0 { - // Calculate bytes transferred since last reading - bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived) - bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted) - - // Handle counter wraparound (if the counter resets or overflows) - if bytesInDiff < 0 { - bytesInDiff = float64(currentReading.BytesReceived) - } - if bytesOutDiff < 0 { - bytesOutDiff = float64(currentReading.BytesTransmitted) - } - - // Convert to MB - bytesInMB := bytesInDiff / (1024 * 1024) - bytesOutMB := bytesOutDiff / (1024 * 1024) - - peerBandwidths = append(peerBandwidths, PeerBandwidth{ - PublicKey: publicKey, - BytesIn: bytesInMB, - BytesOut: bytesOutMB, - }) - } else { - // If readings are too close together or time hasn't passed, report 0 - peerBandwidths = append(peerBandwidths, PeerBandwidth{ - PublicKey: publicKey, - BytesIn: 0, - BytesOut: 0, - }) - } - } else { - // For first reading of a peer, report 0 to establish baseline - peerBandwidths = append(peerBandwidths, PeerBandwidth{ - PublicKey: publicKey, - BytesIn: 0, - BytesOut: 0, - }) - } - - // Update the last reading - lastReadings[publicKey] = currentReading - } - - // Clean up old peers - for publicKey := range lastReadings { - found := false - for _, peer := range device.Peers { - if peer.PublicKey.String() == publicKey { - found = true - break - } - } - if !found { - delete(lastReadings, publicKey) - } - } - - return peerBandwidths, nil -} - -func (s *WireGuardService) reportPeerBandwidth() error { - bandwidths, err := s.calculatePeerBandwidth() - if err != nil { - return fmt.Errorf("failed to calculate peer bandwidth: %v", err) - } - - jsonData, err := json.Marshal(bandwidths) - if err != nil { - return fmt.Errorf("failed to marshal bandwidth data: %v", err) - } - - err = s.client.SendMessage("wg/bandwidth", jsonData) - if err != nil { - return fmt.Errorf("failed to send bandwidth data: %v", err) - } - - return nil -}