From e8bd55bed9d9cf31b2c1b21322348ee6132c8621 Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 20 Feb 2025 20:04:01 -0500 Subject: [PATCH 01/19] Copy in gerbil wg config --- go.mod | 16 +- go.sum | 26 +++ wg/wg.go | 646 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 684 insertions(+), 4 deletions(-) create mode 100644 wg/wg.go diff --git a/go.mod b/go.mod index 2cc0c19..7812b1b 100644 --- a/go.mod +++ b/go.mod @@ -8,13 +8,21 @@ require golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 require ( github.com/google/btree v1.1.2 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect - golang.org/x/crypto v0.28.0 // indirect + github.com/josharian/native v1.1.0 // indirect + github.com/mdlayher/genetlink v1.3.2 // indirect + github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mdlayher/socket v0.5.1 // indirect + github.com/vishvananda/netlink v1.3.0 // indirect + github.com/vishvananda/netns v0.0.4 // indirect + golang.org/x/crypto v0.31.0 // indirect golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect - golang.org/x/net v0.30.0 // indirect - golang.org/x/sys v0.26.0 // indirect + golang.org/x/net v0.33.0 // indirect + golang.org/x/sync v0.10.0 // indirect + golang.org/x/sys v0.28.0 // indirect golang.org/x/time v0.7.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 // indirect + golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 // indirect gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect ) diff --git a/go.sum b/go.sum index d95ab3a..f453d4f 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,39 @@ github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= +github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= +github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= +github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= +github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= +github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= +github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= +github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= +github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA= golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= @@ -18,5 +42,7 @@ golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uI golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU= +golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= diff --git a/wg/wg.go b/wg/wg.go new file mode 100644 index 0000000..ed868ea --- /dev/null +++ b/wg/wg.go @@ -0,0 +1,646 @@ +package wg + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "io" + "net" + "net/http" + "os" + "os/exec" + "strconv" + "strings" + "sync" + "time" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/websocket" + "github.com/vishvananda/netlink" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +var ( + interfaceName string + listenAddr string + mtuInt int + lastReadings = make(map[string]PeerReading) + mu sync.Mutex +) + +type WgConfig struct { + PrivateKey string `json:"privateKey"` + ListenPort int `json:"listenPort"` + IpAddress string `json:"ipAddress"` + Peers []Peer `json:"peers"` +} + +type Peer struct { + PublicKey string `json:"publicKey"` + AllowedIPs []string `json:"allowedIps"` +} + +type PeerBandwidth struct { + PublicKey string `json:"publicKey"` + BytesIn float64 `json:"bytesIn"` + BytesOut float64 `json:"bytesOut"` +} + +type PeerReading struct { + BytesReceived int64 + BytesTransmitted int64 + LastChecked time.Time +} + +var ( + wgClient *wgctrl.Client +) + +func main() { + var ( + err error + wgconfig WgConfig + remoteConfigURL string + generateAndSaveKeyTo string + reachableAt string + logLevel string + mtu string + ) + + interfaceName = os.Getenv("INTERFACE") + remoteConfigURL = os.Getenv("REMOTE_CONFIG") + listenAddr = os.Getenv("LISTEN") + generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") + reachableAt = os.Getenv("REACHABLE_AT") + logLevel = os.Getenv("LOG_LEVEL") + mtu = os.Getenv("MTU") + + if interfaceName == "" { + flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") + } + if remoteConfigURL == "" { + flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL to fetch remote configuration") + } + if generateAndSaveKeyTo == "" { + flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") + } + if reachableAt == "" { + flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about") + } + if logLevel == "" { + flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") + } + if mtu == "" { + flag.StringVar(&mtu, "mtu", "1280", "MTU of the WireGuard interface") + } + flag.Parse() + + mtuInt, err = strconv.Atoi(mtu) + if err != nil { + logger.Fatal("Failed to parse MTU: %v", err) + } + + var key wgtypes.Key + // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file + if generateAndSaveKeyTo != "" { + if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { + // generate a new private key + key, err = wgtypes.GeneratePrivateKey() + if err != nil { + logger.Fatal("Failed to generate private key: %v", err) + } + // save the key to the file + err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644) + if err != nil { + logger.Fatal("Failed to save private key: %v", err) + } + } else { + keyData, err := os.ReadFile(generateAndSaveKeyTo) + if err != nil { + logger.Fatal("Failed to read private key: %v", err) + } + key, err = wgtypes.ParseKey(string(keyData)) + if err != nil { + logger.Fatal("Failed to parse private key: %v", err) + } + } + } else { + // if no generateAndSaveKeyTo is provided, ensure that the private key is provided + if wgconfig.PrivateKey == "" { + // generate a new one + key, err = wgtypes.GeneratePrivateKey() + if err != nil { + logger.Fatal("Failed to generate private key: %v", err) + } + } + } + + // loop until we get the config + for wgconfig.PrivateKey == "" { + logger.Info("Fetching remote config from %s", remoteConfigURL) + wgconfig, err = loadRemoteConfig(remoteConfigURL, key, reachableAt) + if err != nil { + logger.Error("Failed to load configuration: %v", err) + time.Sleep(5 * time.Second) + continue + } + wgconfig.PrivateKey = key.String() + } + + wgClient, err = wgctrl.New() + if err != nil { + logger.Fatal("Failed to create WireGuard client: %v", err) + } + defer wgClient.Close() + + // Ensure the WireGuard interface exists and is configured + if err := ensureWireguardInterface(wgconfig); err != nil { + logger.Fatal("Failed to ensure WireGuard interface: %v", err) + } + + // Ensure the WireGuard peers exist + ensureWireguardPeers(wgconfig.Peers) + + // go periodicBandwidthCheck(reportBandwidthTo) +} + +func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) { + var body *bytes.Buffer + if reachableAt == "" { + body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, key.PublicKey().String()))) + } else { + body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, key.PublicKey().String(), reachableAt))) + } + resp, err := http.Post(url, "application/json", body) + if err != nil { + // print the error + logger.Error("Error fetching remote config %s: %v", url, err) + return WgConfig{}, err + } + defer resp.Body.Close() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return WgConfig{}, err + } + + var config WgConfig + err = json.Unmarshal(data, &config) + + return config, err +} + +func ensureWireguardInterface(wgconfig WgConfig) error { + // Check if the WireGuard interface exists + _, err := netlink.LinkByName(interfaceName) + if err != nil { + if _, ok := err.(netlink.LinkNotFoundError); ok { + // Interface doesn't exist, so create it + err = createWireGuardInterface() + if err != nil { + logger.Fatal("Failed to create WireGuard interface: %v", err) + } + logger.Info("Created WireGuard interface %s\n", interfaceName) + } else { + logger.Fatal("Error checking for WireGuard interface: %v", err) + } + } else { + logger.Info("WireGuard interface %s already exists\n", interfaceName) + return nil + } + + // Assign IP address to the interface + err = assignIPAddress(wgconfig.IpAddress) + if err != nil { + logger.Fatal("Failed to assign IP address: %v", err) + } + logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, interfaceName) + + // Check if the interface already exists + _, err = wgClient.Device(interfaceName) + if err != nil { + return fmt.Errorf("interface %s does not exist", interfaceName) + } + + // Parse the private key + key, err := wgtypes.ParseKey(wgconfig.PrivateKey) + if err != nil { + return fmt.Errorf("failed to parse private key: %v", err) + } + + // Create a new WireGuard configuration + config := wgtypes.Config{ + PrivateKey: &key, + ListenPort: new(int), + } + *config.ListenPort = wgconfig.ListenPort + + // Create and configure the WireGuard interface + err = wgClient.ConfigureDevice(interfaceName, config) + if err != nil { + return fmt.Errorf("failed to configure WireGuard device: %v", err) + } + + // bring up the interface + link, err := netlink.LinkByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface: %v", err) + } + + if err := netlink.LinkSetMTU(link, mtuInt); err != nil { + return fmt.Errorf("failed to set MTU: %v", err) + } + + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to bring up interface: %v", err) + } + + if err := ensureMSSClamping(); err != nil { + logger.Warn("Failed to ensure MSS clamping: %v", err) + } + + logger.Info("WireGuard interface %s created and configured", interfaceName) + + return nil +} + +func createWireGuardInterface() error { + wgLink := &netlink.GenericLink{ + LinkAttrs: netlink.LinkAttrs{Name: interfaceName}, + LinkType: "wireguard", + } + return netlink.LinkAdd(wgLink) +} + +func assignIPAddress(ipAddress string) error { + link, err := netlink.LinkByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface: %v", err) + } + + addr, err := netlink.ParseAddr(ipAddress) + if err != nil { + return fmt.Errorf("failed to parse IP address: %v", err) + } + + return netlink.AddrAdd(link, addr) +} + +func ensureWireguardPeers(peers []Peer) error { + // get the current peers + device, err := wgClient.Device(interfaceName) + if err != nil { + return fmt.Errorf("failed to get device: %v", err) + } + + // get the peer public keys + var currentPeers []string + for _, peer := range device.Peers { + currentPeers = append(currentPeers, peer.PublicKey.String()) + } + + // remove any peers that are not in the config + for _, peer := range currentPeers { + found := false + for _, configPeer := range peers { + if peer == configPeer.PublicKey { + found = true + break + } + } + if !found { + err := removePeer(peer) + if err != nil { + return fmt.Errorf("failed to remove peer: %v", err) + } + } + } + + // add any peers that are in the config but not in the current peers + for _, configPeer := range peers { + found := false + for _, peer := range currentPeers { + if configPeer.PublicKey == peer { + found = true + break + } + } + if !found { + err := addPeer(configPeer) + if err != nil { + return fmt.Errorf("failed to add peer: %v", err) + } + } + } + + return nil +} + +func ensureMSSClamping() error { + // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20)) + mssValue := mtuInt - 40 + + // Rules to be managed - just the chains, we'll construct the full command separately + chains := []string{"INPUT", "OUTPUT", "FORWARD"} + + // First, try to delete any existing rules + for _, chain := range chains { + deleteCmd := exec.Command("/usr/sbin/iptables", + "-t", "mangle", + "-D", chain, + "-p", "tcp", + "--tcp-flags", "SYN,RST", "SYN", + "-j", "TCPMSS", + "--set-mss", fmt.Sprintf("%d", mssValue)) + + logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain) + + // Try deletion multiple times to handle multiple existing rules + for i := 0; i < 3; i++ { + out, err := deleteCmd.CombinedOutput() + if err != nil { + // Convert exit status 1 to string for better logging + if exitErr, ok := err.(*exec.ExitError); ok { + logger.Debug("Deletion stopped for chain %s: %v (output: %s)", + chain, exitErr.String(), string(out)) + } + break // No more rules to delete + } + logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1) + } + } + + // Then add the new rules + var errors []error + for _, chain := range chains { + addCmd := exec.Command("/usr/sbin/iptables", + "-t", "mangle", + "-A", chain, + "-p", "tcp", + "--tcp-flags", "SYN,RST", "SYN", + "-j", "TCPMSS", + "--set-mss", fmt.Sprintf("%d", mssValue)) + + logger.Info("Adding MSS clamping rule for chain %s", chain) + + if out, err := addCmd.CombinedOutput(); err != nil { + errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)", + chain, err, string(out)) + logger.Error(errMsg) + errors = append(errors, fmt.Errorf(errMsg)) + continue + } + + // Verify the rule was added + checkCmd := exec.Command("/usr/sbin/iptables", + "-t", "mangle", + "-C", chain, + "-p", "tcp", + "--tcp-flags", "SYN,RST", "SYN", + "-j", "TCPMSS", + "--set-mss", fmt.Sprintf("%d", mssValue)) + + if out, err := checkCmd.CombinedOutput(); err != nil { + errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)", + chain, err, string(out)) + logger.Error(errMsg) + errors = append(errors, fmt.Errorf(errMsg)) + continue + } + + logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain) + } + + // If we encountered any errors, return them combined + if len(errors) > 0 { + var errMsgs []string + for _, err := range errors { + errMsgs = append(errMsgs, err.Error()) + } + return fmt.Errorf("MSS clamping setup encountered errors:\n%s", + strings.Join(errMsgs, "\n")) + } + + return nil +} + +func handleAddPeer(msg websocket.WSMessage) { + var peer Peer + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + } + + if err := json.Unmarshal(jsonData, &peer); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + } + + err = addPeer(peer) + if err != nil { + return + } +} + +func addPeer(peer Peer) error { + pubKey, err := wgtypes.ParseKey(peer.PublicKey) + if err != nil { + return fmt.Errorf("failed to parse public key: %v", err) + } + + // parse allowed IPs into array of net.IPNet + var allowedIPs []net.IPNet + for _, ipStr := range peer.AllowedIPs { + _, ipNet, err := net.ParseCIDR(ipStr) + if err != nil { + return fmt.Errorf("failed to parse allowed IP: %v", err) + } + allowedIPs = append(allowedIPs, *ipNet) + } + + peerConfig := wgtypes.PeerConfig{ + PublicKey: pubKey, + AllowedIPs: allowedIPs, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peerConfig}, + } + + if err := wgClient.ConfigureDevice(interfaceName, config); err != nil { + return fmt.Errorf("failed to add peer: %v", err) + } + + logger.Info("Peer %s added successfully", peer.PublicKey) + + return nil +} + +func handleRemovePeer(msg websocket.WSMessage) { + // parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" } + type RemoveRequest struct { + PublicKey string `json:"publicKey"` + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + } + + var request RemoveRequest + if err := json.Unmarshal(jsonData, &request); err != nil { + logger.Info("Error unmarshaling data: %v", err) + return + } + + if err := removePeer(request.PublicKey); err != nil { + logger.Info("Error removing peer: %v", err) + return + } +} + +func removePeer(publicKey string) error { + pubKey, err := wgtypes.ParseKey(publicKey) + if err != nil { + return fmt.Errorf("failed to parse public key: %v", err) + } + + peerConfig := wgtypes.PeerConfig{ + PublicKey: pubKey, + Remove: true, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peerConfig}, + } + + if err := wgClient.ConfigureDevice(interfaceName, config); err != nil { + return fmt.Errorf("failed to remove peer: %v", err) + } + + logger.Info("Peer %s removed successfully", publicKey) + + return nil +} + +func periodicBandwidthCheck(endpoint string) { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for range ticker.C { + if err := reportPeerBandwidth(endpoint); err != nil { + logger.Info("Failed to report peer bandwidth: %v", err) + } + } +} + +func calculatePeerBandwidth() ([]PeerBandwidth, error) { + device, err := wgClient.Device(interfaceName) + if err != nil { + return nil, fmt.Errorf("failed to get device: %v", err) + } + + peerBandwidths := []PeerBandwidth{} + now := time.Now() + + mu.Lock() + defer mu.Unlock() + + for _, peer := range device.Peers { + publicKey := peer.PublicKey.String() + currentReading := PeerReading{ + BytesReceived: peer.ReceiveBytes, + BytesTransmitted: peer.TransmitBytes, + LastChecked: now, + } + + var bytesInDiff, bytesOutDiff float64 + lastReading, exists := lastReadings[publicKey] + + if exists { + timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds() + if timeDiff > 0 { + // Calculate bytes transferred since last reading + bytesInDiff = float64(currentReading.BytesReceived - lastReading.BytesReceived) + bytesOutDiff = float64(currentReading.BytesTransmitted - lastReading.BytesTransmitted) + + // Handle counter wraparound (if the counter resets or overflows) + if bytesInDiff < 0 { + bytesInDiff = float64(currentReading.BytesReceived) + } + if bytesOutDiff < 0 { + bytesOutDiff = float64(currentReading.BytesTransmitted) + } + + // Convert to MB + bytesInMB := bytesInDiff / (1024 * 1024) + bytesOutMB := bytesOutDiff / (1024 * 1024) + + peerBandwidths = append(peerBandwidths, PeerBandwidth{ + PublicKey: publicKey, + BytesIn: bytesInMB, + BytesOut: bytesOutMB, + }) + } else { + // If readings are too close together or time hasn't passed, report 0 + peerBandwidths = append(peerBandwidths, PeerBandwidth{ + PublicKey: publicKey, + BytesIn: 0, + BytesOut: 0, + }) + } + } else { + // For first reading of a peer, report 0 to establish baseline + peerBandwidths = append(peerBandwidths, PeerBandwidth{ + PublicKey: publicKey, + BytesIn: 0, + BytesOut: 0, + }) + } + + // Update the last reading + lastReadings[publicKey] = currentReading + } + + // Clean up old peers + for publicKey := range lastReadings { + found := false + for _, peer := range device.Peers { + if peer.PublicKey.String() == publicKey { + found = true + break + } + } + if !found { + delete(lastReadings, publicKey) + } + } + + return peerBandwidths, nil +} + +func reportPeerBandwidth(apiURL string) error { + bandwidths, err := calculatePeerBandwidth() + if err != nil { + return fmt.Errorf("failed to calculate peer bandwidth: %v", err) + } + + jsonData, err := json.Marshal(bandwidths) + if err != nil { + return fmt.Errorf("failed to marshal bandwidth data: %v", err) + } + + resp, err := http.Post(apiURL, "application/json", bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to send bandwidth data: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("API returned non-OK status: %s", resp.Status) + } + + return nil +} From f69a7f647d8cbfcafc8ee72c4600d9a89c8b4b76 Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 20 Feb 2025 20:37:31 -0500 Subject: [PATCH 02/19] Move wg into more of a class --- main.go | 49 +++++++-- wg/wg.go | 295 +++++++++++++++++++++++++++---------------------------- 2 files changed, 184 insertions(+), 160 deletions(-) diff --git a/main.go b/main.go index 786ecbd..9cee671 100644 --- a/main.go +++ b/main.go @@ -20,6 +20,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/proxy" "github.com/fosrl/newt/websocket" + "github.com/fosrl/newt/wg" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" @@ -246,15 +247,18 @@ func resolveDomain(domain string) (string, error) { func main() { var ( - endpoint string - id string - secret string - mtu string - mtuInt int - dns string - privateKey wgtypes.Key - err error - logLevel string + endpoint string + id string + secret string + mtu string + mtuInt int + dns string + privateKey wgtypes.Key + err error + logLevel string + interfaceName string + generateAndSaveKeyTo string + reachableAt string ) // if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values @@ -264,6 +268,9 @@ func main() { mtu = os.Getenv("MTU") dns = os.Getenv("DNS") logLevel = os.Getenv("LOG_LEVEL") + interfaceName = os.Getenv("INTERFACE") + generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") + reachableAt = os.Getenv("REACHABLE_AT") if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") @@ -283,6 +290,15 @@ func main() { if logLevel == "" { flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") } + if interfaceName == "" { + flag.StringVar(&interfaceName, "interface", "wg-1", "Name of the WireGuard interface") + } + if generateAndSaveKeyTo == "" { + flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") + } + if reachableAt == "" { + flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about") + } // do a --version check version := flag.Bool("version", false, "Print the version") @@ -319,6 +335,21 @@ func main() { logger.Fatal("Failed to create client: %v", err) } + // Create WireGuard service + wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client) + if err != nil { + logger.Fatal("Failed to create WireGuard service: %v", err) + } + // defer wgService.Close() + + // Start the WireGuard service + if err := wgService.Start(); err != nil { + logger.Fatal("Failed to start WireGuard service: %v", err) + } + + // Start bandwidth reporting + wgService.StartBandwidthReporting() + // Create TUN device and network stack var tun tun.Device var tnet *netstack.Net diff --git a/wg/wg.go b/wg/wg.go index ed868ea..dc5e337 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -3,14 +3,10 @@ package wg import ( "bytes" "encoding/json" - "flag" "fmt" - "io" "net" - "net/http" "os" "os/exec" - "strconv" "strings" "sync" "time" @@ -58,147 +54,149 @@ var ( wgClient *wgctrl.Client ) -func main() { - var ( - err error - wgconfig WgConfig - remoteConfigURL string - generateAndSaveKeyTo string - reachableAt string - logLevel string - mtu string - ) +type WireGuardService struct { + interfaceName string + mtu int + client *websocket.Client + wgClient *wgctrl.Client + config WgConfig + key wgtypes.Key + reachableAt string + generateAndSaveKeyTo string + lastReadings map[string]PeerReading + mu sync.Mutex +} - interfaceName = os.Getenv("INTERFACE") - remoteConfigURL = os.Getenv("REMOTE_CONFIG") - listenAddr = os.Getenv("LISTEN") - generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") - reachableAt = os.Getenv("REACHABLE_AT") - logLevel = os.Getenv("LOG_LEVEL") - mtu = os.Getenv("MTU") - - if interfaceName == "" { - flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") - } - if remoteConfigURL == "" { - flag.StringVar(&remoteConfigURL, "remoteConfig", "", "URL to fetch remote configuration") - } - if generateAndSaveKeyTo == "" { - flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") - } - if reachableAt == "" { - flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about") - } - if logLevel == "" { - flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") - } - if mtu == "" { - flag.StringVar(&mtu, "mtu", "1280", "MTU of the WireGuard interface") - } - flag.Parse() - - mtuInt, err = strconv.Atoi(mtu) +func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) { + wgClient, err := wgctrl.New() if err != nil { - logger.Fatal("Failed to parse MTU: %v", err) + return nil, fmt.Errorf("failed to create WireGuard client: %v", err) } - var key wgtypes.Key + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, fmt.Errorf("failed to generate private key: %v", err) + } + + service := &WireGuardService{ + interfaceName: interfaceName, + mtu: mtu, + client: wsClient, + wgClient: wgClient, + key: key, + reachableAt: reachableAt, + generateAndSaveKeyTo: generateAndSaveKeyTo, + lastReadings: make(map[string]PeerReading), + } + + // Register websocket handlers + wsClient.RegisterHandler("wg/peer/config", service.handleConfig) + wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer) + wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer) + + // Register connect handler to initiate configuration + wsClient.OnConnect(service.handleConnect) + + return service, nil +} + +func (s *WireGuardService) handleConnect() error { + logger.Debug("Public key: %s", s.key.PublicKey()) + + err := s.client.SendMessage("wg/register", map[string]interface{}{ + "publicKey": fmt.Sprintf("%s", s.key.PublicKey()), + }) + if err != nil { + logger.Error("Failed to send registration message: %v", err) + return err + } + + logger.Info("Sent registration message") + return nil +} + +func (s *WireGuardService) Start() error { + // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file - if generateAndSaveKeyTo != "" { - if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { - // generate a new private key - key, err = wgtypes.GeneratePrivateKey() - if err != nil { - logger.Fatal("Failed to generate private key: %v", err) - } - // save the key to the file - err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644) - if err != nil { - logger.Fatal("Failed to save private key: %v", err) - } - } else { - keyData, err := os.ReadFile(generateAndSaveKeyTo) - if err != nil { - logger.Fatal("Failed to read private key: %v", err) - } - key, err = wgtypes.ParseKey(string(keyData)) - if err != nil { - logger.Fatal("Failed to parse private key: %v", err) - } - } - } else { - // if no generateAndSaveKeyTo is provided, ensure that the private key is provided - if wgconfig.PrivateKey == "" { - // generate a new one - key, err = wgtypes.GeneratePrivateKey() - if err != nil { - logger.Fatal("Failed to generate private key: %v", err) - } - } - } - - // loop until we get the config - for wgconfig.PrivateKey == "" { - logger.Info("Fetching remote config from %s", remoteConfigURL) - wgconfig, err = loadRemoteConfig(remoteConfigURL, key, reachableAt) + if _, err := os.Stat(s.generateAndSaveKeyTo); os.IsNotExist(err) { + // generate a new private key + s.key, err = wgtypes.GeneratePrivateKey() if err != nil { - logger.Error("Failed to load configuration: %v", err) - time.Sleep(5 * time.Second) - continue + logger.Fatal("Failed to generate private key: %v", err) + } + // save the key to the file + err = os.WriteFile(s.generateAndSaveKeyTo, []byte(s.key.String()), 0644) + if err != nil { + logger.Fatal("Failed to save private key: %v", err) } - wgconfig.PrivateKey = key.String() - } - - wgClient, err = wgctrl.New() - if err != nil { - logger.Fatal("Failed to create WireGuard client: %v", err) - } - defer wgClient.Close() - - // Ensure the WireGuard interface exists and is configured - if err := ensureWireguardInterface(wgconfig); err != nil { - logger.Fatal("Failed to ensure WireGuard interface: %v", err) - } - - // Ensure the WireGuard peers exist - ensureWireguardPeers(wgconfig.Peers) - - // go periodicBandwidthCheck(reportBandwidthTo) -} - -func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) { - var body *bytes.Buffer - if reachableAt == "" { - body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, key.PublicKey().String()))) } else { - body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, key.PublicKey().String(), reachableAt))) + keyData, err := os.ReadFile(s.generateAndSaveKeyTo) + if err != nil { + logger.Fatal("Failed to read private key: %v", err) + } + s.key, err = wgtypes.ParseKey(string(keyData)) + if err != nil { + logger.Fatal("Failed to parse private key: %v", err) + } } - resp, err := http.Post(url, "application/json", body) + + // Get initial configuration + err := s.loadRemoteConfig() if err != nil { - // print the error - logger.Error("Error fetching remote config %s: %v", url, err) - return WgConfig{}, err - } - defer resp.Body.Close() - - data, err := io.ReadAll(resp.Body) - if err != nil { - return WgConfig{}, err + return fmt.Errorf("failed to load initial configuration: %v", err) } - var config WgConfig - err = json.Unmarshal(data, &config) - - return config, err + return nil } -func ensureWireguardInterface(wgconfig WgConfig) error { +func (s *WireGuardService) StartBandwidthReporting() { + go s.periodicBandwidthCheck() +} + +func (s *WireGuardService) loadRemoteConfig() error { + body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, s.key.PublicKey().String(), s.reachableAt))) + + // send a ws message to the server to get the config + + err := s.client.SendMessage("wg/config/get", body) + if err != nil { + return fmt.Errorf("failed to send config request: %v", err) + } + + return nil +} + +func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { + var config WgConfig + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + } + + if err := json.Unmarshal(jsonData, &config); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + } + + s.config = config + + // Ensure the WireGuard interface and peers are configured + if err := s.ensureWireguardInterface(config); err != nil { + logger.Error("Failed to ensure WireGuard interface: %v", err) + } + + if err := s.ensureWireguardPeers(config.Peers); err != nil { + logger.Error("Failed to ensure WireGuard peers: %v", err) + } +} + +func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { // Check if the WireGuard interface exists _, err := netlink.LinkByName(interfaceName) if err != nil { if _, ok := err.(netlink.LinkNotFoundError); ok { // Interface doesn't exist, so create it - err = createWireGuardInterface() + err = s.createWireGuardInterface() if err != nil { logger.Fatal("Failed to create WireGuard interface: %v", err) } @@ -212,7 +210,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error { } // Assign IP address to the interface - err = assignIPAddress(wgconfig.IpAddress) + err = s.assignIPAddress(wgconfig.IpAddress) if err != nil { logger.Fatal("Failed to assign IP address: %v", err) } @@ -257,7 +255,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error { return fmt.Errorf("failed to bring up interface: %v", err) } - if err := ensureMSSClamping(); err != nil { + if err := s.ensureMSSClamping(); err != nil { logger.Warn("Failed to ensure MSS clamping: %v", err) } @@ -266,7 +264,7 @@ func ensureWireguardInterface(wgconfig WgConfig) error { return nil } -func createWireGuardInterface() error { +func (s *WireGuardService) createWireGuardInterface() error { wgLink := &netlink.GenericLink{ LinkAttrs: netlink.LinkAttrs{Name: interfaceName}, LinkType: "wireguard", @@ -274,7 +272,7 @@ func createWireGuardInterface() error { return netlink.LinkAdd(wgLink) } -func assignIPAddress(ipAddress string) error { +func (s *WireGuardService) assignIPAddress(ipAddress string) error { link, err := netlink.LinkByName(interfaceName) if err != nil { return fmt.Errorf("failed to get interface: %v", err) @@ -288,7 +286,7 @@ func assignIPAddress(ipAddress string) error { return netlink.AddrAdd(link, addr) } -func ensureWireguardPeers(peers []Peer) error { +func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { // get the current peers device, err := wgClient.Device(interfaceName) if err != nil { @@ -311,7 +309,7 @@ func ensureWireguardPeers(peers []Peer) error { } } if !found { - err := removePeer(peer) + err := s.removePeer(peer) if err != nil { return fmt.Errorf("failed to remove peer: %v", err) } @@ -328,7 +326,7 @@ func ensureWireguardPeers(peers []Peer) error { } } if !found { - err := addPeer(configPeer) + err := s.addPeer(configPeer) if err != nil { return fmt.Errorf("failed to add peer: %v", err) } @@ -338,7 +336,7 @@ func ensureWireguardPeers(peers []Peer) error { return nil } -func ensureMSSClamping() error { +func (s *WireGuardService) ensureMSSClamping() error { // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20)) mssValue := mtuInt - 40 @@ -426,7 +424,7 @@ func ensureMSSClamping() error { return nil } -func handleAddPeer(msg websocket.WSMessage) { +func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) { var peer Peer jsonData, err := json.Marshal(msg.Data) @@ -438,13 +436,13 @@ func handleAddPeer(msg websocket.WSMessage) { logger.Info("Error unmarshaling target data: %v", err) } - err = addPeer(peer) + err = s.addPeer(peer) if err != nil { return } } -func addPeer(peer Peer) error { +func (s *WireGuardService) addPeer(peer Peer) error { pubKey, err := wgtypes.ParseKey(peer.PublicKey) if err != nil { return fmt.Errorf("failed to parse public key: %v", err) @@ -478,7 +476,7 @@ func addPeer(peer Peer) error { return nil } -func handleRemovePeer(msg websocket.WSMessage) { +func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) { // parse the publicKey from the message which is json { "publicKey": "asdfasdfl;akjsdf" } type RemoveRequest struct { PublicKey string `json:"publicKey"` @@ -495,13 +493,13 @@ func handleRemovePeer(msg websocket.WSMessage) { return } - if err := removePeer(request.PublicKey); err != nil { + if err := s.removePeer(request.PublicKey); err != nil { logger.Info("Error removing peer: %v", err) return } } -func removePeer(publicKey string) error { +func (s *WireGuardService) removePeer(publicKey string) error { pubKey, err := wgtypes.ParseKey(publicKey) if err != nil { return fmt.Errorf("failed to parse public key: %v", err) @@ -525,18 +523,18 @@ func removePeer(publicKey string) error { return nil } -func periodicBandwidthCheck(endpoint string) { +func (s *WireGuardService) periodicBandwidthCheck() { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() for range ticker.C { - if err := reportPeerBandwidth(endpoint); err != nil { + if err := s.reportPeerBandwidth(); err != nil { logger.Info("Failed to report peer bandwidth: %v", err) } } } -func calculatePeerBandwidth() ([]PeerBandwidth, error) { +func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { device, err := wgClient.Device(interfaceName) if err != nil { return nil, fmt.Errorf("failed to get device: %v", err) @@ -621,8 +619,8 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) { return peerBandwidths, nil } -func reportPeerBandwidth(apiURL string) error { - bandwidths, err := calculatePeerBandwidth() +func (s *WireGuardService) reportPeerBandwidth() error { + bandwidths, err := s.calculatePeerBandwidth() if err != nil { return fmt.Errorf("failed to calculate peer bandwidth: %v", err) } @@ -632,15 +630,10 @@ func reportPeerBandwidth(apiURL string) error { return fmt.Errorf("failed to marshal bandwidth data: %v", err) } - resp, err := http.Post(apiURL, "application/json", bytes.NewBuffer(jsonData)) + err = s.client.SendMessage("wg/bandwidth", jsonData) if err != nil { return fmt.Errorf("failed to send bandwidth data: %v", err) } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("API returned non-OK status: %s", resp.Status) - } return nil } From 66edae42885341958ca019c9f60cc60c2226d21d Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 20 Feb 2025 21:01:44 -0500 Subject: [PATCH 03/19] Clean up implementation --- main.go | 10 +-- wg/wg.go | 263 ++++++++++++++++++++++++------------------------------- 2 files changed, 116 insertions(+), 157 deletions(-) diff --git a/main.go b/main.go index 9cee671..942acef 100644 --- a/main.go +++ b/main.go @@ -340,15 +340,7 @@ func main() { if err != nil { logger.Fatal("Failed to create WireGuard service: %v", err) } - // defer wgService.Close() - - // Start the WireGuard service - if err := wgService.Start(); err != nil { - logger.Fatal("Failed to start WireGuard service: %v", err) - } - - // Start bandwidth reporting - wgService.StartBandwidthReporting() + defer wgService.Close() // Create TUN device and network stack var tun tun.Device diff --git a/wg/wg.go b/wg/wg.go index dc5e337..4699ed7 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -6,8 +6,6 @@ import ( "fmt" "net" "os" - "os/exec" - "strings" "sync" "time" @@ -55,16 +53,15 @@ var ( ) type WireGuardService struct { - interfaceName string - mtu int - client *websocket.Client - wgClient *wgctrl.Client - config WgConfig - key wgtypes.Key - reachableAt string - generateAndSaveKeyTo string - lastReadings map[string]PeerReading - mu sync.Mutex + interfaceName string + mtu int + client *websocket.Client + wgClient *wgctrl.Client + config WgConfig + key wgtypes.Key + reachableAt string + lastReadings map[string]PeerReading + mu sync.Mutex } func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) { @@ -73,90 +70,60 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene return nil, fmt.Errorf("failed to create WireGuard client: %v", err) } - key, err := wgtypes.GeneratePrivateKey() - if err != nil { - return nil, fmt.Errorf("failed to generate private key: %v", err) - } - - service := &WireGuardService{ - interfaceName: interfaceName, - mtu: mtu, - client: wsClient, - wgClient: wgClient, - key: key, - reachableAt: reachableAt, - generateAndSaveKeyTo: generateAndSaveKeyTo, - lastReadings: make(map[string]PeerReading), - } - - // Register websocket handlers - wsClient.RegisterHandler("wg/peer/config", service.handleConfig) - wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer) - wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer) - - // Register connect handler to initiate configuration - wsClient.OnConnect(service.handleConnect) - - return service, nil -} - -func (s *WireGuardService) handleConnect() error { - logger.Debug("Public key: %s", s.key.PublicKey()) - - err := s.client.SendMessage("wg/register", map[string]interface{}{ - "publicKey": fmt.Sprintf("%s", s.key.PublicKey()), - }) - if err != nil { - logger.Error("Failed to send registration message: %v", err) - return err - } - - logger.Info("Sent registration message") - return nil -} - -func (s *WireGuardService) Start() error { - + key := wgtypes.Key{} // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file - if _, err := os.Stat(s.generateAndSaveKeyTo); os.IsNotExist(err) { + if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { // generate a new private key - s.key, err = wgtypes.GeneratePrivateKey() + key, err = wgtypes.GeneratePrivateKey() if err != nil { logger.Fatal("Failed to generate private key: %v", err) } // save the key to the file - err = os.WriteFile(s.generateAndSaveKeyTo, []byte(s.key.String()), 0644) + err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644) if err != nil { logger.Fatal("Failed to save private key: %v", err) } } else { - keyData, err := os.ReadFile(s.generateAndSaveKeyTo) + keyData, err := os.ReadFile(generateAndSaveKeyTo) if err != nil { logger.Fatal("Failed to read private key: %v", err) } - s.key, err = wgtypes.ParseKey(string(keyData)) + key, err = wgtypes.ParseKey(string(keyData)) if err != nil { logger.Fatal("Failed to parse private key: %v", err) } } - // Get initial configuration - err := s.loadRemoteConfig() - if err != nil { - return fmt.Errorf("failed to load initial configuration: %v", err) + service := &WireGuardService{ + interfaceName: interfaceName, + mtu: mtu, + client: wsClient, + wgClient: wgClient, + key: key, + reachableAt: reachableAt, + lastReadings: make(map[string]PeerReading), } - return nil + // Register websocket handlers + wsClient.RegisterHandler("wg/config/receive", service.handleConfig) + wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer) + wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer) + + // Register connect handler to initiate configuration + wsClient.OnConnect(service.loadRemoteConfig) + + return service, nil } -func (s *WireGuardService) StartBandwidthReporting() { - go s.periodicBandwidthCheck() +func (s *WireGuardService) Close() { + s.client.Close() + wgClient.Close() } func (s *WireGuardService) loadRemoteConfig() error { - body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, s.key.PublicKey().String(), s.reachableAt))) + body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "endpoint": "%s"}`, s.key.PublicKey().String(), s.reachableAt))) - // send a ws message to the server to get the config + go s.periodicBandwidthCheck() err := s.client.SendMessage("wg/config/get", body) if err != nil { @@ -255,9 +222,9 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { return fmt.Errorf("failed to bring up interface: %v", err) } - if err := s.ensureMSSClamping(); err != nil { - logger.Warn("Failed to ensure MSS clamping: %v", err) - } + // if err := s.ensureMSSClamping(); err != nil { + // logger.Warn("Failed to ensure MSS clamping: %v", err) + // } logger.Info("WireGuard interface %s created and configured", interfaceName) @@ -336,93 +303,93 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { return nil } -func (s *WireGuardService) ensureMSSClamping() error { - // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20)) - mssValue := mtuInt - 40 +// func (s *WireGuardService) ensureMSSClamping() error { +// // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20)) +// mssValue := mtuInt - 40 - // Rules to be managed - just the chains, we'll construct the full command separately - chains := []string{"INPUT", "OUTPUT", "FORWARD"} +// // Rules to be managed - just the chains, we'll construct the full command separately +// chains := []string{"INPUT", "OUTPUT", "FORWARD"} - // First, try to delete any existing rules - for _, chain := range chains { - deleteCmd := exec.Command("/usr/sbin/iptables", - "-t", "mangle", - "-D", chain, - "-p", "tcp", - "--tcp-flags", "SYN,RST", "SYN", - "-j", "TCPMSS", - "--set-mss", fmt.Sprintf("%d", mssValue)) +// // First, try to delete any existing rules +// for _, chain := range chains { +// deleteCmd := exec.Command("/usr/sbin/iptables", +// "-t", "mangle", +// "-D", chain, +// "-p", "tcp", +// "--tcp-flags", "SYN,RST", "SYN", +// "-j", "TCPMSS", +// "--set-mss", fmt.Sprintf("%d", mssValue)) - logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain) +// logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain) - // Try deletion multiple times to handle multiple existing rules - for i := 0; i < 3; i++ { - out, err := deleteCmd.CombinedOutput() - if err != nil { - // Convert exit status 1 to string for better logging - if exitErr, ok := err.(*exec.ExitError); ok { - logger.Debug("Deletion stopped for chain %s: %v (output: %s)", - chain, exitErr.String(), string(out)) - } - break // No more rules to delete - } - logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1) - } - } +// // Try deletion multiple times to handle multiple existing rules +// for i := 0; i < 3; i++ { +// out, err := deleteCmd.CombinedOutput() +// if err != nil { +// // Convert exit status 1 to string for better logging +// if exitErr, ok := err.(*exec.ExitError); ok { +// logger.Debug("Deletion stopped for chain %s: %v (output: %s)", +// chain, exitErr.String(), string(out)) +// } +// break // No more rules to delete +// } +// logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1) +// } +// } - // Then add the new rules - var errors []error - for _, chain := range chains { - addCmd := exec.Command("/usr/sbin/iptables", - "-t", "mangle", - "-A", chain, - "-p", "tcp", - "--tcp-flags", "SYN,RST", "SYN", - "-j", "TCPMSS", - "--set-mss", fmt.Sprintf("%d", mssValue)) +// // Then add the new rules +// var errors []error +// for _, chain := range chains { +// addCmd := exec.Command("/usr/sbin/iptables", +// "-t", "mangle", +// "-A", chain, +// "-p", "tcp", +// "--tcp-flags", "SYN,RST", "SYN", +// "-j", "TCPMSS", +// "--set-mss", fmt.Sprintf("%d", mssValue)) - logger.Info("Adding MSS clamping rule for chain %s", chain) +// logger.Info("Adding MSS clamping rule for chain %s", chain) - if out, err := addCmd.CombinedOutput(); err != nil { - errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)", - chain, err, string(out)) - logger.Error(errMsg) - errors = append(errors, fmt.Errorf(errMsg)) - continue - } +// if out, err := addCmd.CombinedOutput(); err != nil { +// errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)", +// chain, err, string(out)) +// logger.Error(errMsg) +// errors = append(errors, fmt.Errorf(errMsg)) +// continue +// } - // Verify the rule was added - checkCmd := exec.Command("/usr/sbin/iptables", - "-t", "mangle", - "-C", chain, - "-p", "tcp", - "--tcp-flags", "SYN,RST", "SYN", - "-j", "TCPMSS", - "--set-mss", fmt.Sprintf("%d", mssValue)) +// // Verify the rule was added +// checkCmd := exec.Command("/usr/sbin/iptables", +// "-t", "mangle", +// "-C", chain, +// "-p", "tcp", +// "--tcp-flags", "SYN,RST", "SYN", +// "-j", "TCPMSS", +// "--set-mss", fmt.Sprintf("%d", mssValue)) - if out, err := checkCmd.CombinedOutput(); err != nil { - errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)", - chain, err, string(out)) - logger.Error(errMsg) - errors = append(errors, fmt.Errorf(errMsg)) - continue - } +// if out, err := checkCmd.CombinedOutput(); err != nil { +// errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)", +// chain, err, string(out)) +// logger.Error(errMsg) +// errors = append(errors, fmt.Errorf(errMsg)) +// continue +// } - logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain) - } +// logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain) +// } - // If we encountered any errors, return them combined - if len(errors) > 0 { - var errMsgs []string - for _, err := range errors { - errMsgs = append(errMsgs, err.Error()) - } - return fmt.Errorf("MSS clamping setup encountered errors:\n%s", - strings.Join(errMsgs, "\n")) - } +// // If we encountered any errors, return them combined +// if len(errors) > 0 { +// var errMsgs []string +// for _, err := range errors { +// errMsgs = append(errMsgs, err.Error()) +// } +// return fmt.Errorf("MSS clamping setup encountered errors:\n%s", +// strings.Join(errMsgs, "\n")) +// } - return nil -} +// return nil +// } func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) { var peer Peer From fb199cc94be248891effff07cabd42c35415f7df Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 20 Feb 2025 22:07:27 -0500 Subject: [PATCH 04/19] Tidy --- go.mod | 15 ++++++++------- go.sum | 12 ++---------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/go.mod b/go.mod index 7812b1b..08b4716 100644 --- a/go.mod +++ b/go.mod @@ -4,25 +4,26 @@ go 1.23.1 toolchain go1.23.2 -require golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 +require ( + github.com/gorilla/websocket v1.5.3 + github.com/vishvananda/netlink v1.3.0 + golang.org/x/net v0.33.0 + golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 + golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 + gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 +) require ( github.com/google/btree v1.1.2 // indirect github.com/google/go-cmp v0.6.0 // indirect - github.com/gorilla/websocket v1.5.3 // indirect github.com/josharian/native v1.1.0 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/socket v0.5.1 // indirect - github.com/vishvananda/netlink v1.3.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect golang.org/x/crypto v0.31.0 // indirect - golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect - golang.org/x/net v0.33.0 // indirect golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.28.0 // indirect golang.org/x/time v0.7.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 // indirect - gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect ) diff --git a/go.sum b/go.sum index f453d4f..2f56ede 100644 --- a/go.sum +++ b/go.sum @@ -12,26 +12,20 @@ github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/ github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= +github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= -golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= -golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA= -golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= -golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= -golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= @@ -40,8 +34,6 @@ golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeu golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4= golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= -golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= -golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= From 45a1ab91d7f68acaf223aef4d619d5ebcf26bc1d Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 20 Feb 2025 22:10:02 -0500 Subject: [PATCH 05/19] Dont always do wg --- main.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/main.go b/main.go index 942acef..29e25ee 100644 --- a/main.go +++ b/main.go @@ -335,12 +335,14 @@ func main() { logger.Fatal("Failed to create client: %v", err) } - // Create WireGuard service - wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client) - if err != nil { - logger.Fatal("Failed to create WireGuard service: %v", err) + if reachableAt != "" { + // Create WireGuard service + wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client) + if err != nil { + logger.Fatal("Failed to create WireGuard service: %v", err) + } + defer wgService.Close() } - defer wgService.Close() // Create TUN device and network stack var tun tun.Device @@ -417,7 +419,7 @@ func main() { public_key=%s allowed_ip=%s/32 endpoint=%s -persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint) +persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint) err = dev.IpcSet(config) if err != nil { @@ -549,7 +551,7 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( logger.Debug("Public key: %s", publicKey) err := client.SendMessage("newt/wg/register", map[string]interface{}{ - "publicKey": fmt.Sprintf("%s", publicKey), + "publicKey": publicKey.PublicKey(), }) if err != nil { logger.Error("Failed to send registration message: %v", err) From 56e75902e3941db8bbeac406fae0065d6ede55aa Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 21 Feb 2025 12:44:52 -0500 Subject: [PATCH 06/19] Adjust ws types --- main.go | 2 +- wg/wg.go | 14 ++++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/main.go b/main.go index 29e25ee..0592a51 100644 --- a/main.go +++ b/main.go @@ -55,7 +55,7 @@ func fixKey(key string) string { // Decode from base64 decoded, err := base64.StdEncoding.DecodeString(key) if err != nil { - logger.Fatal("Error decoding base64:", err) + logger.Fatal("Error decoding base64") } // Convert to hex diff --git a/wg/wg.go b/wg/wg.go index 4699ed7..7e19958 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -18,7 +18,6 @@ import ( var ( interfaceName string - listenAddr string mtuInt int lastReadings = make(map[string]PeerReading) mu sync.Mutex @@ -61,7 +60,6 @@ type WireGuardService struct { key wgtypes.Key reachableAt string lastReadings map[string]PeerReading - mu sync.Mutex } func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) { @@ -70,7 +68,7 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene return nil, fmt.Errorf("failed to create WireGuard client: %v", err) } - key := wgtypes.Key{} + var key wgtypes.Key // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { // generate a new private key @@ -105,9 +103,9 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene } // Register websocket handlers - wsClient.RegisterHandler("wg/config/receive", service.handleConfig) - wsClient.RegisterHandler("wg/peer/add", service.handleAddPeer) - wsClient.RegisterHandler("wg/peer/remove", service.handleRemovePeer) + wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig) + wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer) + wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer) // Register connect handler to initiate configuration wsClient.OnConnect(service.loadRemoteConfig) @@ -121,11 +119,11 @@ func (s *WireGuardService) Close() { } func (s *WireGuardService) loadRemoteConfig() error { - body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "endpoint": "%s"}`, s.key.PublicKey().String(), s.reachableAt))) + body := bytes.NewBuffer([]byte(fmt.Sprintf(`{ "publicKey": "%s", "endpoint": "%s" }`, s.key.PublicKey().String(), s.reachableAt))) go s.periodicBandwidthCheck() - err := s.client.SendMessage("wg/config/get", body) + err := s.client.SendMessage("newt/wg/get-config", body) if err != nil { return fmt.Errorf("failed to send config request: %v", err) } From 95eab504fac0c87077d75152c03c149c3fc23efd Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 21 Feb 2025 16:12:12 -0500 Subject: [PATCH 07/19] Get wg working --- main.go | 29 ++++++++++-------- wg/wg.go | 90 ++++++++++++++++++++++++-------------------------------- 2 files changed, 56 insertions(+), 63 deletions(-) diff --git a/main.go b/main.go index 0592a51..1f0b289 100644 --- a/main.go +++ b/main.go @@ -291,7 +291,7 @@ func main() { flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") } if interfaceName == "" { - flag.StringVar(&interfaceName, "interface", "wg-1", "Name of the WireGuard interface") + flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") } if generateAndSaveKeyTo == "" { flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") @@ -335,15 +335,7 @@ func main() { logger.Fatal("Failed to create client: %v", err) } - if reachableAt != "" { - // Create WireGuard service - wgService, err := wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client) - if err != nil { - logger.Fatal("Failed to create WireGuard service: %v", err) - } - defer wgService.Close() - } - + var wgService *wg.WireGuardService // Create TUN device and network stack var tun tun.Device var tnet *netstack.Net @@ -352,6 +344,16 @@ func main() { var connected bool var wgData WgData + if reachableAt != "" { + logger.Info("Sending reachableAt to server: %s", reachableAt) + // Create WireGuard service + wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client) + if err != nil { + logger.Fatal("Failed to create WireGuard service: %v", err) + } + defer wgService.Close() + } + client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) { logger.Info("Received terminate message") if pm != nil { @@ -419,7 +421,7 @@ func main() { public_key=%s allowed_ip=%s/32 endpoint=%s -persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint) +persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint) err = dev.IpcSet(config) if err != nil { @@ -439,6 +441,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub if err != nil { // Handle complete failure after all retries logger.Error("Failed to ping %s: %v", wgData.ServerIP, err) + fmt.Sprintf("%s", privateKey) } if !connected { @@ -551,13 +554,15 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub logger.Debug("Public key: %s", publicKey) err := client.SendMessage("newt/wg/register", map[string]interface{}{ - "publicKey": publicKey.PublicKey(), + "publicKey": fmt.Sprintf("%s", publicKey), }) if err != nil { logger.Error("Failed to send registration message: %v", err) return err } + wgService.LoadRemoteConfig() + logger.Info("Sent registration message") return nil }) diff --git a/wg/wg.go b/wg/wg.go index 7e19958..6df69a7 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -1,7 +1,6 @@ package wg import ( - "bytes" "encoding/json" "fmt" "net" @@ -16,13 +15,6 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -var ( - interfaceName string - mtuInt int - lastReadings = make(map[string]PeerReading) - mu sync.Mutex -) - type WgConfig struct { PrivateKey string `json:"privateKey"` ListenPort int `json:"listenPort"` @@ -47,10 +39,6 @@ type PeerReading struct { LastChecked time.Time } -var ( - wgClient *wgctrl.Client -) - type WireGuardService struct { interfaceName string mtu int @@ -60,6 +48,7 @@ type WireGuardService struct { key wgtypes.Key reachableAt string lastReadings map[string]PeerReading + mu sync.Mutex } func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) { @@ -107,27 +96,29 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer) wsClient.RegisterHandler("newt/wg/peer/remove", service.handleRemovePeer) - // Register connect handler to initiate configuration - wsClient.OnConnect(service.loadRemoteConfig) - return service, nil } func (s *WireGuardService) Close() { s.client.Close() - wgClient.Close() + s.wgClient.Close() } -func (s *WireGuardService) loadRemoteConfig() error { - body := bytes.NewBuffer([]byte(fmt.Sprintf(`{ "publicKey": "%s", "endpoint": "%s" }`, s.key.PublicKey().String(), s.reachableAt))) +func (s *WireGuardService) LoadRemoteConfig() error { + + err := s.client.SendMessage("newt/wg/get-config", map[string]interface{}{ + "publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()), + "endpoint": s.reachableAt, + }) + if err != nil { + logger.Error("Failed to send registration message: %v", err) + return err + } + + logger.Info("Requesting WireGuard configuration from remote server") go s.periodicBandwidthCheck() - err := s.client.SendMessage("newt/wg/get-config", body) - if err != nil { - return fmt.Errorf("failed to send config request: %v", err) - } - return nil } @@ -157,7 +148,7 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { // Check if the WireGuard interface exists - _, err := netlink.LinkByName(interfaceName) + _, err := netlink.LinkByName(s.interfaceName) if err != nil { if _, ok := err.(netlink.LinkNotFoundError); ok { // Interface doesn't exist, so create it @@ -165,12 +156,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { if err != nil { logger.Fatal("Failed to create WireGuard interface: %v", err) } - logger.Info("Created WireGuard interface %s\n", interfaceName) + logger.Info("Created WireGuard interface %s\n", s.interfaceName) } else { logger.Fatal("Error checking for WireGuard interface: %v", err) } } else { - logger.Info("WireGuard interface %s already exists\n", interfaceName) + logger.Info("WireGuard interface %s already exists\n", s.interfaceName) return nil } @@ -179,12 +170,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { if err != nil { logger.Fatal("Failed to assign IP address: %v", err) } - logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, interfaceName) + logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, s.interfaceName) // Check if the interface already exists - _, err = wgClient.Device(interfaceName) + _, err = s.wgClient.Device(s.interfaceName) if err != nil { - return fmt.Errorf("interface %s does not exist", interfaceName) + return fmt.Errorf("interface %s does not exist", s.interfaceName) } // Parse the private key @@ -201,18 +192,18 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { *config.ListenPort = wgconfig.ListenPort // Create and configure the WireGuard interface - err = wgClient.ConfigureDevice(interfaceName, config) + err = s.wgClient.ConfigureDevice(s.interfaceName, config) if err != nil { return fmt.Errorf("failed to configure WireGuard device: %v", err) } // bring up the interface - link, err := netlink.LinkByName(interfaceName) + link, err := netlink.LinkByName(s.interfaceName) if err != nil { return fmt.Errorf("failed to get interface: %v", err) } - if err := netlink.LinkSetMTU(link, mtuInt); err != nil { + if err := netlink.LinkSetMTU(link, s.mtu); err != nil { return fmt.Errorf("failed to set MTU: %v", err) } @@ -224,21 +215,21 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { // logger.Warn("Failed to ensure MSS clamping: %v", err) // } - logger.Info("WireGuard interface %s created and configured", interfaceName) + logger.Info("WireGuard interface %s created and configured", s.interfaceName) return nil } func (s *WireGuardService) createWireGuardInterface() error { wgLink := &netlink.GenericLink{ - LinkAttrs: netlink.LinkAttrs{Name: interfaceName}, + LinkAttrs: netlink.LinkAttrs{Name: s.interfaceName}, LinkType: "wireguard", } return netlink.LinkAdd(wgLink) } func (s *WireGuardService) assignIPAddress(ipAddress string) error { - link, err := netlink.LinkByName(interfaceName) + link, err := netlink.LinkByName(s.interfaceName) if err != nil { return fmt.Errorf("failed to get interface: %v", err) } @@ -253,7 +244,7 @@ func (s *WireGuardService) assignIPAddress(ipAddress string) error { func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { // get the current peers - device, err := wgClient.Device(interfaceName) + device, err := s.wgClient.Device(s.interfaceName) if err != nil { return fmt.Errorf("failed to get device: %v", err) } @@ -432,7 +423,7 @@ func (s *WireGuardService) addPeer(peer Peer) error { Peers: []wgtypes.PeerConfig{peerConfig}, } - if err := wgClient.ConfigureDevice(interfaceName, config); err != nil { + if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { return fmt.Errorf("failed to add peer: %v", err) } @@ -479,7 +470,7 @@ func (s *WireGuardService) removePeer(publicKey string) error { Peers: []wgtypes.PeerConfig{peerConfig}, } - if err := wgClient.ConfigureDevice(interfaceName, config); err != nil { + if err := s.wgClient.ConfigureDevice(s.interfaceName, config); err != nil { return fmt.Errorf("failed to remove peer: %v", err) } @@ -500,7 +491,7 @@ func (s *WireGuardService) periodicBandwidthCheck() { } func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { - device, err := wgClient.Device(interfaceName) + device, err := s.wgClient.Device(s.interfaceName) if err != nil { return nil, fmt.Errorf("failed to get device: %v", err) } @@ -508,8 +499,8 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { peerBandwidths := []PeerBandwidth{} now := time.Now() - mu.Lock() - defer mu.Unlock() + s.mu.Lock() + defer s.mu.Unlock() for _, peer := range device.Peers { publicKey := peer.PublicKey.String() @@ -520,7 +511,7 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { } var bytesInDiff, bytesOutDiff float64 - lastReading, exists := lastReadings[publicKey] + lastReading, exists := s.lastReadings[publicKey] if exists { timeDiff := currentReading.LastChecked.Sub(lastReading.LastChecked).Seconds() @@ -564,11 +555,11 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { } // Update the last reading - lastReadings[publicKey] = currentReading + s.lastReadings[publicKey] = currentReading } // Clean up old peers - for publicKey := range lastReadings { + for publicKey := range s.lastReadings { found := false for _, peer := range device.Peers { if peer.PublicKey.String() == publicKey { @@ -577,7 +568,7 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) { } } if !found { - delete(lastReadings, publicKey) + delete(s.lastReadings, publicKey) } } @@ -590,12 +581,9 @@ func (s *WireGuardService) reportPeerBandwidth() error { return fmt.Errorf("failed to calculate peer bandwidth: %v", err) } - jsonData, err := json.Marshal(bandwidths) - if err != nil { - return fmt.Errorf("failed to marshal bandwidth data: %v", err) - } - - err = s.client.SendMessage("wg/bandwidth", jsonData) + err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{ + "bandwidthData": bandwidths, + }) if err != nil { return fmt.Errorf("failed to send bandwidth data: %v", err) } From bff6707577a830860a10aa174065b00057d909f7 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 21 Feb 2025 16:20:03 -0500 Subject: [PATCH 08/19] Basic create wg seems to be working --- main.go | 2 +- websocket/client.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index 1f0b289..8e81054 100644 --- a/main.go +++ b/main.go @@ -291,7 +291,7 @@ func main() { flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") } if interfaceName == "" { - flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") + flag.StringVar(&interfaceName, "interface", "wg1", "Name of the WireGuard interface") } if generateAndSaveKeyTo == "" { flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") diff --git a/websocket/client.go b/websocket/client.go index 8a7d3f9..08b9167 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -288,6 +288,7 @@ func (c *Client) establishConnection() error { // Add token to query parameters q := u.Query() q.Set("token", token) + q.Set("clientType", "newt") u.RawQuery = q.Encode() // Connect to WebSocket From 18d99de924309bf91ba5c0684fd64463d11cdeeb Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 21 Feb 2025 17:13:00 -0500 Subject: [PATCH 09/19] Handle messages correctly --- wg/wg.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/wg/wg.go b/wg/wg.go index 6df69a7..da08f84 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -125,16 +125,21 @@ func (s *WireGuardService) LoadRemoteConfig() error { func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { var config WgConfig + logger.Info("Received WireGuard configuration") + jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Info("Error marshaling data: %v", err) + return } if err := json.Unmarshal(jsonData, &config); err != nil { logger.Info("Error unmarshaling target data: %v", err) + return } s.config = config + logger.Info("Config: %v", s.config) // Ensure the WireGuard interface and peers are configured if err := s.ensureWireguardInterface(config); err != nil { @@ -165,12 +170,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { return nil } + logger.Info("Assigning IP address %s to interface %s\n", wgconfig.IpAddress, s.interfaceName) // Assign IP address to the interface err = s.assignIPAddress(wgconfig.IpAddress) if err != nil { logger.Fatal("Failed to assign IP address: %v", err) } - logger.Info("Assigned IP address %s to interface %s\n", wgconfig.IpAddress, s.interfaceName) // Check if the interface already exists _, err = s.wgClient.Device(s.interfaceName) From 0affef401c0a9994d11e83eea7fb2d903ca8c430 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 21 Feb 2025 18:04:36 -0500 Subject: [PATCH 10/19] Properly handle key --- wg/wg.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/wg/wg.go b/wg/wg.go index da08f84..9f962b5 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -16,7 +16,6 @@ import ( ) type WgConfig struct { - PrivateKey string `json:"privateKey"` ListenPort int `json:"listenPort"` IpAddress string `json:"ipAddress"` Peers []Peer `json:"peers"` @@ -125,7 +124,7 @@ func (s *WireGuardService) LoadRemoteConfig() error { func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { var config WgConfig - logger.Info("Received WireGuard configuration") + logger.Info("Received message: %v", msg) jsonData, err := json.Marshal(msg.Data) if err != nil { @@ -137,9 +136,7 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { logger.Info("Error unmarshaling target data: %v", err) return } - s.config = config - logger.Info("Config: %v", s.config) // Ensure the WireGuard interface and peers are configured if err := s.ensureWireguardInterface(config); err != nil { @@ -184,7 +181,7 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { } // Parse the private key - key, err := wgtypes.ParseKey(wgconfig.PrivateKey) + key, err := wgtypes.ParseKey(s.key.String()) if err != nil { return fmt.Errorf("failed to parse private key: %v", err) } From 270ee9cd190d3a16faa28dbbef833a095a8c4b60 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 21 Feb 2025 20:33:31 -0500 Subject: [PATCH 11/19] Fix panic --- wg/wg.go | 1 - 1 file changed, 1 deletion(-) diff --git a/wg/wg.go b/wg/wg.go index 9f962b5..4f388c1 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -99,7 +99,6 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene } func (s *WireGuardService) Close() { - s.client.Close() s.wgClient.Close() } From afa93d8a3fbeb2821c56fc06107101703f113b10 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 21 Feb 2025 22:27:24 -0500 Subject: [PATCH 12/19] Add static port and udp hole punch --- main.go | 2 +- wg/wg.go | 72 +++++++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 70 insertions(+), 4 deletions(-) diff --git a/main.go b/main.go index 8e81054..139da58 100644 --- a/main.go +++ b/main.go @@ -347,7 +347,7 @@ func main() { if reachableAt != "" { logger.Info("Sending reachableAt to server: %s", reachableAt) // Create WireGuard service - wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, client) + wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, endpoint, id, client) if err != nil { logger.Fatal("Failed to create WireGuard service: %v", err) } diff --git a/wg/wg.go b/wg/wg.go index 4f388c1..9b3a137 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -11,6 +11,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/websocket" "github.com/vishvananda/netlink" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -46,11 +47,31 @@ type WireGuardService struct { config WgConfig key wgtypes.Key reachableAt string + newtId string lastReadings map[string]PeerReading mu sync.Mutex + port uint16 } -func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, wsClient *websocket.Client) (*WireGuardService, error) { +// Add this type definition +type fixedPortBind struct { + port uint16 + conn.Bind +} + +func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { + // Ignore the requested port and use our fixed port + return b.Bind.Open(b.port) +} + +func NewFixedPortBind(port uint16) conn.Bind { + return &fixedPortBind{ + port: port, + Bind: conn.NewDefaultBind(), + } +} + +func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, endpoint string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) { wgClient, err := wgctrl.New() if err != nil { return nil, fmt.Errorf("failed to create WireGuard client: %v", err) @@ -87,7 +108,14 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene wgClient: wgClient, key: key, reachableAt: reachableAt, + newtId: newtId, lastReadings: make(map[string]PeerReading), + port: 21821, + } + + if err := service.sendUDPHolePunch(endpoint + ":21820"); err != nil { + logger.Error("Failed to send UDP hole punch: %v", err) + // Continue anyway as this is just for NAT traversal } // Register websocket handlers @@ -185,12 +213,13 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { return fmt.Errorf("failed to parse private key: %v", err) } - // Create a new WireGuard configuration config := wgtypes.Config{ PrivateKey: &key, ListenPort: new(int), } - *config.ListenPort = wgconfig.ListenPort + + // Use the service's fixed port instead of the config port + *config.ListenPort = int(s.port) // Create and configure the WireGuard interface err = s.wgClient.ConfigureDevice(s.interfaceName, config) @@ -591,3 +620,40 @@ func (s *WireGuardService) reportPeerBandwidth() error { return nil } + +func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { + // Bind to specific local port + localAddr := &net.UDPAddr{ + Port: int(s.port), + IP: net.IPv4zero, + } + + remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + return fmt.Errorf("failed to resolve UDP address: %v", err) + } + + conn, err := net.ListenUDP("udp", localAddr) + if err != nil { + return fmt.Errorf("failed to bind UDP socket: %v", err) + } + defer conn.Close() + + payload := struct { + NewtID string `json:"newtId"` + }{ + NewtID: s.newtId, + } + + data, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %v", err) + } + + _, err = conn.WriteToUDP(data, remoteAddr) + if err != nil { + return fmt.Errorf("failed to send UDP packet: %v", err) + } + + return nil +} From 4aa718d55f27e2a9f7394fe3409bf5c44249e408 Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 22 Feb 2025 11:21:13 -0500 Subject: [PATCH 13/19] Initial hp working but need to fix port issue --- main.go | 12 ++++++--- wg/wg.go | 76 ++++++++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 77 insertions(+), 11 deletions(-) diff --git a/main.go b/main.go index 139da58..e08eefd 100644 --- a/main.go +++ b/main.go @@ -344,10 +344,16 @@ func main() { var connected bool var wgData WgData - if reachableAt != "" { - logger.Info("Sending reachableAt to server: %s", reachableAt) + if generateAndSaveKeyTo != "" { + var host = endpoint + if strings.HasPrefix(host, "http://") { + host = strings.TrimPrefix(host, "http://") + } else if strings.HasPrefix(host, "https://") { + host = strings.TrimPrefix(host, "https://") + } + // Create WireGuard service - wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, endpoint, id, client) + wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, host, id, client) if err != nil { logger.Fatal("Failed to create WireGuard service: %v", err) } diff --git a/wg/wg.go b/wg/wg.go index 9b3a137..58bf02a 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -25,6 +25,7 @@ type WgConfig struct { type Peer struct { PublicKey string `json:"publicKey"` AllowedIPs []string `json:"allowedIps"` + Endpoint string `json:"endpoint"` } type PeerBandwidth struct { @@ -51,6 +52,7 @@ type WireGuardService struct { lastReadings map[string]PeerReading mu sync.Mutex port uint16 + stopHolepunch chan struct{} } // Add this type definition @@ -71,7 +73,35 @@ func NewFixedPortBind(port uint16) conn.Bind { } } -func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, endpoint string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) { +func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { + if maxPort < minPort { + return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) + } + + for port := minPort; port <= maxPort; port++ { + // Create the UDP address to test + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: int(port), + } + + // Attempt to create a UDP listener + conn, err := net.ListenUDP("udp", addr) + if err != nil { + continue // Port is in use or there was an error, try next port + } + + // Close the connection immediately + _ = conn.SetDeadline(time.Now()) + conn.Close() + + return port, nil + } + + return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort) +} + +func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) { wgClient, err := wgctrl.New() if err != nil { return nil, fmt.Errorf("failed to create WireGuard client: %v", err) @@ -101,6 +131,12 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene } } + port, err := FindAvailableUDPPort(49152, 65535) + if err != nil { + fmt.Printf("Error finding available port: %v\n", err) + return nil, err + } + service := &WireGuardService{ interfaceName: interfaceName, mtu: mtu, @@ -110,13 +146,12 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene reachableAt: reachableAt, newtId: newtId, lastReadings: make(map[string]PeerReading), - port: 21821, + port: port, + stopHolepunch: make(chan struct{}), } - if err := service.sendUDPHolePunch(endpoint + ":21820"); err != nil { - logger.Error("Failed to send UDP hole punch: %v", err) - // Continue anyway as this is just for NAT traversal - } + // start the UDP holepunch + go service.keepSendingUDPHolePunch(host) // Register websocket handlers wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig) @@ -443,10 +478,18 @@ func (s *WireGuardService) addPeer(peer Peer) error { } allowedIPs = append(allowedIPs, *ipNet) } + // add keep alive using *time.Duration of 1 second + keepalive := time.Second + endpoint, err := net.ResolveUDPAddr("udp", peer.Endpoint) + if err != nil { + return fmt.Errorf("failed to resolve endpoint address: %w", err) + } peerConfig := wgtypes.PeerConfig{ - PublicKey: pubKey, - AllowedIPs: allowedIPs, + PublicKey: pubKey, + AllowedIPs: allowedIPs, + PersistentKeepaliveInterval: &keepalive, + Endpoint: endpoint, } config := wgtypes.Config{ @@ -657,3 +700,20 @@ func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { return nil } + +func (s *WireGuardService) keepSendingUDPHolePunch(host string) { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-s.stopHolepunch: + logger.Info("Stopping UDP holepunch") + return + case <-ticker.C: + if err := s.sendUDPHolePunch(host + ":21820"); err != nil { + logger.Error("Failed to send UDP hole punch: %v", err) + } + } + } +} From 8795c57b2e806a2e4e71ea026f920a7e0006b00a Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 22 Feb 2025 12:53:23 -0500 Subject: [PATCH 14/19] HP works! --- go.mod | 3 ++- go.sum | 4 ++++ wg/wg.go | 32 +++++++++++++++++++++++++------- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index 08b4716..c6931ef 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,8 @@ require ( github.com/mdlayher/socket v0.5.1 // indirect github.com/vishvananda/netns v0.0.4 // indirect golang.org/x/crypto v0.31.0 // indirect - golang.org/x/sync v0.10.0 // indirect + golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect + golang.org/x/sync v0.11.0 // indirect golang.org/x/sys v0.28.0 // indirect golang.org/x/time v0.7.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect diff --git a/go.sum b/go.sum index 2f56ede..a0deda0 100644 --- a/go.sum +++ b/go.sum @@ -20,10 +20,14 @@ github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1Y github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4= +golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= diff --git a/wg/wg.go b/wg/wg.go index 58bf02a..6883ca9 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -11,6 +11,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/websocket" "github.com/vishvananda/netlink" + "golang.org/x/exp/rand" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -78,23 +79,31 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) } - for port := minPort; port <= maxPort; port++ { - // Create the UDP address to test + // Create a slice of all ports in the range + portRange := make([]uint16, maxPort-minPort+1) + for i := range portRange { + portRange[i] = minPort + uint16(i) + } + + // Fisher-Yates shuffle to randomize the port order + rand.Seed(uint64(time.Now().UnixNano())) + for i := len(portRange) - 1; i > 0; i-- { + j := rand.Intn(i + 1) + portRange[i], portRange[j] = portRange[j], portRange[i] + } + + // Try each port in the randomized order + for _, port := range portRange { addr := &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: int(port), } - - // Attempt to create a UDP listener conn, err := net.ListenUDP("udp", addr) if err != nil { continue // Port is in use or there was an error, try next port } - - // Close the connection immediately _ = conn.SetDeadline(time.Now()) conn.Close() - return port, nil } @@ -150,6 +159,10 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene stopHolepunch: make(chan struct{}), } + if err := service.sendUDPHolePunch(host + ":21820"); err != nil { + logger.Error("Failed to send UDP hole punch: %v", err) + } + // start the UDP holepunch go service.keepSendingUDPHolePunch(host) @@ -200,6 +213,9 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { } s.config = config + // stop the holepunch + close(s.stopHolepunch) + // Ensure the WireGuard interface and peers are configured if err := s.ensureWireguardInterface(config); err != nil { logger.Error("Failed to ensure WireGuard interface: %v", err) @@ -698,6 +714,8 @@ func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { return fmt.Errorf("failed to send UDP packet: %v", err) } + logger.Info("Sent UDP hole punch to %s", serverAddr) + return nil } From f6429b6eeeb91db5b308600998cfa06560bd9168 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 23 Feb 2025 00:59:51 -0500 Subject: [PATCH 15/19] Basic holepunch working --- go.mod | 1 + go.sum | 14 ++++ network/network.go | 202 +++++++++++++++++++++++++++++++++++++++++++++ nohup.out | 25 ++++++ wg/wg.go | 159 ++++++++++++----------------------- 5 files changed, 295 insertions(+), 106 deletions(-) create mode 100644 network/network.go create mode 100644 nohup.out diff --git a/go.mod b/go.mod index c6931ef..c9d2752 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( require ( github.com/google/btree v1.1.2 // indirect github.com/google/go-cmp v0.6.0 // indirect + github.com/google/gopacket v1.1.19 // indirect github.com/josharian/native v1.1.0 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/netlink v1.7.2 // indirect diff --git a/go.sum b/go.sum index a0deda0..5e6875a 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= @@ -18,22 +20,34 @@ github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQ github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4= golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4= diff --git a/network/network.go b/network/network.go new file mode 100644 index 0000000..0703e8b --- /dev/null +++ b/network/network.go @@ -0,0 +1,202 @@ +package network + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "log" + "net" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/vishvananda/netlink" + "golang.org/x/net/bpf" + "golang.org/x/net/ipv4" +) + +const ( + udpProtocol = 17 + // EmptyUDPSize is the size of an empty UDP packet + EmptyUDPSize = 28 + timeout = time.Second * 10 +) + +// Server stores data relating to the server +type Server struct { + Hostname string + Addr *net.IPAddr + Port uint16 +} + +// PeerNet stores data about a peer's endpoint +type PeerNet struct { + Resolved bool + IP net.IP + Port uint16 + NewtID string +} + +// GetClientIP gets source ip address that will be used when sending data to dstIP +func GetClientIP(dstIP net.IP) net.IP { + routes, err := netlink.RouteGet(dstIP) + if err != nil { + log.Fatalln("Error getting route:", err) + } + return routes[0].Src +} + +// HostToAddr resolves a hostname, whether DNS or IP to a valid net.IPAddr +func HostToAddr(hostStr string) *net.IPAddr { + remoteAddrs, err := net.LookupHost(hostStr) + if err != nil { + log.Fatalln("Error parsing remote address:", err) + } + + for _, addrStr := range remoteAddrs { + if remoteAddr, err := net.ResolveIPAddr("ip4", addrStr); err == nil { + return remoteAddr + } + } + return nil +} + +// SetupRawConn creates an ipv4 and udp only RawConn and applies packet filtering +func SetupRawConn(server *Server, client *PeerNet) *ipv4.RawConn { + packetConn, err := net.ListenPacket("ip4:udp", client.IP.String()) + if err != nil { + log.Fatalln("Error creating packetConn:", err) + } + + rawConn, err := ipv4.NewRawConn(packetConn) + if err != nil { + log.Fatalln("Error creating rawConn:", err) + } + + ApplyBPF(rawConn, server, client) + + return rawConn +} + +// ApplyBPF constructs a BPF program and applies it to the RawConn +func ApplyBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet) { + const ipv4HeaderLen = 20 + const srcIPOffset = 12 + const srcPortOffset = ipv4HeaderLen + 0 + const dstPortOffset = ipv4HeaderLen + 2 + + ipArr := []byte(server.Addr.IP.To4()) + ipInt := uint32(ipArr[0])<<(3*8) + uint32(ipArr[1])<<(2*8) + uint32(ipArr[2])<<8 + uint32(ipArr[3]) + + bpfRaw, err := bpf.Assemble([]bpf.Instruction{ + bpf.LoadAbsolute{Off: srcIPOffset, Size: 4}, + bpf.JumpIf{Cond: bpf.JumpEqual, Val: ipInt, SkipFalse: 5, SkipTrue: 0}, + + bpf.LoadAbsolute{Off: srcPortOffset, Size: 2}, + bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(server.Port), SkipFalse: 3, SkipTrue: 0}, + + bpf.LoadAbsolute{Off: dstPortOffset, Size: 2}, + bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0}, + + bpf.RetConstant{Val: 1<<(8*4) - 1}, + bpf.RetConstant{Val: 0}, + }) + + if err != nil { + log.Fatalln("Error assembling BPF:", err) + } + + err = rawConn.SetBPF(bpfRaw) + if err != nil { + log.Fatalln("Error setting BPF:", err) + } +} + +// MakePacket constructs a request packet to send to the server +func MakePacket(payload []byte, server *Server, client *PeerNet) []byte { + buf := gopacket.NewSerializeBuffer() + + opts := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + + ipHeader := layers.IPv4{ + SrcIP: client.IP, + DstIP: server.Addr.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + + udpHeader := layers.UDP{ + SrcPort: layers.UDPPort(client.Port), + DstPort: layers.UDPPort(server.Port), + } + + payloadLayer := gopacket.Payload(payload) + + udpHeader.SetNetworkLayerForChecksum(&ipHeader) + + gopacket.SerializeLayers(buf, opts, &ipHeader, &udpHeader, &payloadLayer) + + return buf.Bytes() +} + +// SendPacket sends packet to the Server +func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *PeerNet) error { + fullPacket := MakePacket(packet, server, client) + _, err := conn.WriteToIP(fullPacket, server.Addr) + return err +} + +// SendDataPacket sends a JSON payload to the Server +func SendDataPacket(data interface{}, conn *ipv4.RawConn, server *Server, client *PeerNet) error { + jsonData, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal payload: %v", err) + } + + return SendPacket(jsonData, conn, server, client) +} + +// RecvPacket receives a UDP packet from server +func RecvPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, int, error) { + err := conn.SetReadDeadline(time.Now().Add(timeout)) + if err != nil { + return nil, 0, err + } + + response := make([]byte, 4096) + n, err := conn.Read(response) + if err != nil { + return nil, n, err + } + return response, n, nil +} + +// RecvDataPacket receives and unmarshals a JSON packet from server +func RecvDataPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, error) { + response, n, err := RecvPacket(conn, server, client) + if err != nil { + return nil, err + } + + // Extract payload from UDP packet + payload := response[EmptyUDPSize:n] + return payload, nil +} + +// ParseResponse takes a response packet and parses it into an IP and port +func ParseResponse(response []byte) (net.IP, uint16) { + ip := net.IP(response[:4]) + port := binary.BigEndian.Uint16(response[4:6]) + return ip, port +} + +func parseForBPF(response []byte) (srcIP net.IP, srcPort uint16, dstPort uint16) { + srcIP = net.IP(response[12:16]) + srcPort = binary.BigEndian.Uint16(response[20:22]) + dstPort = binary.BigEndian.Uint16(response[22:24]) + return +} diff --git a/nohup.out b/nohup.out new file mode 100644 index 0000000..58bc6f3 --- /dev/null +++ b/nohup.out @@ -0,0 +1,25 @@ +INFO: 2025/02/22 23:25:47 Requesting WireGuard configuration from remote server +INFO: 2025/02/22 23:25:47 Sent registration message +INFO: 2025/02/22 23:25:47 Received message: {newt/wg/receive-config map[ipAddress:100.90.128.1/24 listenPort:51822 peers:[]]} +INFO: 2025/02/22 23:25:47 Created WireGuard interface wg1 +INFO: 2025/02/22 23:25:47 Assigning IP address 100.90.128.1/24 to interface wg1 +INFO: 2025/02/22 23:25:47 WireGuard interface wg1 created and configured +INFO: 2025/02/22 23:25:47 Received registration message +INFO: 2025/02/22 23:25:47 Received: {Type:newt/wg/connect Data:map[endpoint:pangolin.fosrl.io:51820 publicKey:tng9Z/BN32flFjqwwT1yAxN/twFkmgbZA+D9N+YqdjM= serverIP:100.89.128.1 targets:map[tcp:[] udp:[]] tunnelIP:100.89.128.4]} +INFO: 2025/02/22 23:25:47 WireGuard device created. Lets ping the server now... +INFO: 2025/02/22 23:25:47 Ping attempt 1 of 5 +INFO: 2025/02/22 23:25:47 Pinging 100.89.128.1 +INFO: 2025/02/22 23:25:47 Ping latency: 9.00105ms +INFO: 2025/02/22 23:25:47 Starting ping check +INFO: 2025/02/22 23:26:48 Peer P9pacnRfUSfvDibaQTdTk59q27eRpgtbMMmMpkNwKl0= removed successfully +INFO: 2025/02/22 23:26:48 Peer NMrcorGgTTi4tAUZ1lLru0qISNrt9D9JdsFGyDYlcSQ= added successfully +INFO: 2025/02/22 23:28:58 Peer NMrcorGgTTi4tAUZ1lLru0qISNrt9D9JdsFGyDYlcSQ= removed successfully +INFO: 2025/02/22 23:28:58 Peer n8ZKTG8vsROL/OiqHYJELU/Rg9XDifz0YjE/lQsL0m0= added successfully +INFO: 2025/02/22 23:33:59 Peer n8ZKTG8vsROL/OiqHYJELU/Rg9XDifz0YjE/lQsL0m0= removed successfully +INFO: 2025/02/22 23:33:59 Peer /i8YTgrLkZh08HKXLXqNFQJsyg1E8I2ELXqF0zuP9D8= added successfully +INFO: 2025/02/22 23:34:06 Peer /i8YTgrLkZh08HKXLXqNFQJsyg1E8I2ELXqF0zuP9D8= removed successfully +INFO: 2025/02/22 23:34:06 Peer 50+RB00sDoSG+KAKzl/baaqPkKGOe7upX7uqRCKqsRo= added successfully +INFO: 2025/02/22 23:35:07 Peer 50+RB00sDoSG+KAKzl/baaqPkKGOe7upX7uqRCKqsRo= removed successfully +INFO: 2025/02/22 23:35:07 Peer Aa2Y2NEmc+SITlT89+fsOeqDkXJVu9RBY14+77TXa3w= added successfully +INFO: 2025/02/23 00:55:55 Peer Aa2Y2NEmc+SITlT89+fsOeqDkXJVu9RBY14+77TXa3w= removed successfully +INFO: 2025/02/23 00:55:55 Peer 2AXNjMQzT7GGvdbIG6MJVFpO3FIzQ+qCqZkdSnBA3DE= added successfully diff --git a/wg/wg.go b/wg/wg.go index 6883ca9..fa6760f 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -5,10 +5,13 @@ import ( "fmt" "net" "os" + "strconv" + "strings" "sync" "time" "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/network" "github.com/fosrl/newt/websocket" "github.com/vishvananda/netlink" "golang.org/x/exp/rand" @@ -214,7 +217,7 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { s.config = config // stop the holepunch - close(s.stopHolepunch) + // close(s.stopHolepunch) // Ensure the WireGuard interface and peers are configured if err := s.ensureWireguardInterface(config); err != nil { @@ -373,94 +376,6 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { return nil } -// func (s *WireGuardService) ensureMSSClamping() error { -// // Calculate MSS value (MTU - 40 for IPv4 header (20) and TCP header (20)) -// mssValue := mtuInt - 40 - -// // Rules to be managed - just the chains, we'll construct the full command separately -// chains := []string{"INPUT", "OUTPUT", "FORWARD"} - -// // First, try to delete any existing rules -// for _, chain := range chains { -// deleteCmd := exec.Command("/usr/sbin/iptables", -// "-t", "mangle", -// "-D", chain, -// "-p", "tcp", -// "--tcp-flags", "SYN,RST", "SYN", -// "-j", "TCPMSS", -// "--set-mss", fmt.Sprintf("%d", mssValue)) - -// logger.Info("Attempting to delete existing MSS clamping rule for chain %s", chain) - -// // Try deletion multiple times to handle multiple existing rules -// for i := 0; i < 3; i++ { -// out, err := deleteCmd.CombinedOutput() -// if err != nil { -// // Convert exit status 1 to string for better logging -// if exitErr, ok := err.(*exec.ExitError); ok { -// logger.Debug("Deletion stopped for chain %s: %v (output: %s)", -// chain, exitErr.String(), string(out)) -// } -// break // No more rules to delete -// } -// logger.Info("Deleted MSS clamping rule for chain %s (attempt %d)", chain, i+1) -// } -// } - -// // Then add the new rules -// var errors []error -// for _, chain := range chains { -// addCmd := exec.Command("/usr/sbin/iptables", -// "-t", "mangle", -// "-A", chain, -// "-p", "tcp", -// "--tcp-flags", "SYN,RST", "SYN", -// "-j", "TCPMSS", -// "--set-mss", fmt.Sprintf("%d", mssValue)) - -// logger.Info("Adding MSS clamping rule for chain %s", chain) - -// if out, err := addCmd.CombinedOutput(); err != nil { -// errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)", -// chain, err, string(out)) -// logger.Error(errMsg) -// errors = append(errors, fmt.Errorf(errMsg)) -// continue -// } - -// // Verify the rule was added -// checkCmd := exec.Command("/usr/sbin/iptables", -// "-t", "mangle", -// "-C", chain, -// "-p", "tcp", -// "--tcp-flags", "SYN,RST", "SYN", -// "-j", "TCPMSS", -// "--set-mss", fmt.Sprintf("%d", mssValue)) - -// if out, err := checkCmd.CombinedOutput(); err != nil { -// errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)", -// chain, err, string(out)) -// logger.Error(errMsg) -// errors = append(errors, fmt.Errorf(errMsg)) -// continue -// } - -// logger.Info("Successfully added and verified MSS clamping rule for chain %s", chain) -// } - -// // If we encountered any errors, return them combined -// if len(errors) > 0 { -// var errMsgs []string -// for _, err := range errors { -// errMsgs = append(errMsgs, err.Error()) -// } -// return fmt.Errorf("MSS clamping setup encountered errors:\n%s", -// strings.Join(errMsgs, "\n")) -// } - -// return nil -// } - func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) { var peer Peer @@ -681,40 +596,72 @@ func (s *WireGuardService) reportPeerBandwidth() error { } func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { - // Bind to specific local port - localAddr := &net.UDPAddr{ - Port: int(s.port), - IP: net.IPv4zero, + // Parse server address + serverSplit := strings.Split(serverAddr, ":") + if len(serverSplit) < 2 { + return fmt.Errorf("invalid server address format, expected hostname:port") } - remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + serverHostname := serverSplit[0] + serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16) if err != nil { - return fmt.Errorf("failed to resolve UDP address: %v", err) + return fmt.Errorf("failed to parse server port: %v", err) } - conn, err := net.ListenUDP("udp", localAddr) - if err != nil { - return fmt.Errorf("failed to bind UDP socket: %v", err) + // Resolve server hostname to IP + serverIPAddr := network.HostToAddr(serverHostname) + if serverIPAddr == nil { + return fmt.Errorf("failed to resolve server hostname") } - defer conn.Close() + // Get client IP based on route to server + clientIP := network.GetClientIP(serverIPAddr.IP) + + // Create server and client configs + server := &network.Server{ + Hostname: serverHostname, + Addr: serverIPAddr, + Port: uint16(serverPort), + } + + client := &network.PeerNet{ + IP: clientIP, + Port: s.port, + NewtID: s.newtId, + } + + // Setup raw connection with BPF filtering + rawConn := network.SetupRawConn(server, client) + defer rawConn.Close() + + // Create JSON payload payload := struct { NewtID string `json:"newtId"` }{ NewtID: s.newtId, } - data, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal payload: %v", err) - } - - _, err = conn.WriteToUDP(data, remoteAddr) + // Send the packet using the raw connection + err = network.SendDataPacket(payload, rawConn, server, client) if err != nil { return fmt.Errorf("failed to send UDP packet: %v", err) } - logger.Info("Sent UDP hole punch to %s", serverAddr) + // logger.Info("Sent UDP hole punch to %s", serverAddr) + + // // Wait for response if needed + // response, err := network.RecvDataPacket(rawConn, server, client) + // if err != nil { + // if err, ok := err.(net.Error); ok && err.Timeout() { + // return fmt.Errorf("connection to %s timed out", serverAddr) + // } + // return fmt.Errorf("error receiving response: %v", err) + // } + + // // Process response if needed + // if len(response) > 0 { + // logger.Info("Received response from server") + // } return nil } From b68502de9e81d754c57eae3d88fe128a93112026 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 23 Feb 2025 16:49:24 -0500 Subject: [PATCH 16/19] Basic relay working! --- .gitignore | 3 ++- main.go | 14 ++++++----- nohup.out | 25 ------------------ wg/wg.go | 74 +++++++++++++++++++++++++----------------------------- 4 files changed, 44 insertions(+), 72 deletions(-) delete mode 100644 nohup.out diff --git a/.gitignore b/.gitignore index 8b1c477..100fc81 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ newt .DS_Store -bin/ \ No newline at end of file +bin/ +nohup.out \ No newline at end of file diff --git a/main.go b/main.go index e08eefd..da80c48 100644 --- a/main.go +++ b/main.go @@ -258,7 +258,6 @@ func main() { logLevel string interfaceName string generateAndSaveKeyTo string - reachableAt string ) // if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values @@ -270,7 +269,6 @@ func main() { logLevel = os.Getenv("LOG_LEVEL") interfaceName = os.Getenv("INTERFACE") generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") - reachableAt = os.Getenv("REACHABLE_AT") if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") @@ -296,9 +294,6 @@ func main() { if generateAndSaveKeyTo == "" { flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") } - if reachableAt == "" { - flag.StringVar(&reachableAt, "reachableAt", "", "Endpoint of the http server to tell remote config about") - } // do a --version check version := flag.Bool("version", false, "Print the version") @@ -353,7 +348,7 @@ func main() { } // Create WireGuard service - wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, reachableAt, generateAndSaveKeyTo, host, id, client) + wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client) if err != nil { logger.Fatal("Failed to create WireGuard service: %v", err) } @@ -469,6 +464,13 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP}) } + // first make sure the wpgService has a port + if wgService != nil { + // add a udp proxy for localost and the wgService port + // TODO: make sure this port is not used in a target + pm.AddTarget("udp", wgData.TunnelIP, int(wgService.Port), fmt.Sprintf("localhost:%d", wgService.Port)) + } + err = pm.Start() if err != nil { logger.Error("Failed to start proxy manager: %v", err) diff --git a/nohup.out b/nohup.out deleted file mode 100644 index 58bc6f3..0000000 --- a/nohup.out +++ /dev/null @@ -1,25 +0,0 @@ -INFO: 2025/02/22 23:25:47 Requesting WireGuard configuration from remote server -INFO: 2025/02/22 23:25:47 Sent registration message -INFO: 2025/02/22 23:25:47 Received message: {newt/wg/receive-config map[ipAddress:100.90.128.1/24 listenPort:51822 peers:[]]} -INFO: 2025/02/22 23:25:47 Created WireGuard interface wg1 -INFO: 2025/02/22 23:25:47 Assigning IP address 100.90.128.1/24 to interface wg1 -INFO: 2025/02/22 23:25:47 WireGuard interface wg1 created and configured -INFO: 2025/02/22 23:25:47 Received registration message -INFO: 2025/02/22 23:25:47 Received: {Type:newt/wg/connect Data:map[endpoint:pangolin.fosrl.io:51820 publicKey:tng9Z/BN32flFjqwwT1yAxN/twFkmgbZA+D9N+YqdjM= serverIP:100.89.128.1 targets:map[tcp:[] udp:[]] tunnelIP:100.89.128.4]} -INFO: 2025/02/22 23:25:47 WireGuard device created. Lets ping the server now... -INFO: 2025/02/22 23:25:47 Ping attempt 1 of 5 -INFO: 2025/02/22 23:25:47 Pinging 100.89.128.1 -INFO: 2025/02/22 23:25:47 Ping latency: 9.00105ms -INFO: 2025/02/22 23:25:47 Starting ping check -INFO: 2025/02/22 23:26:48 Peer P9pacnRfUSfvDibaQTdTk59q27eRpgtbMMmMpkNwKl0= removed successfully -INFO: 2025/02/22 23:26:48 Peer NMrcorGgTTi4tAUZ1lLru0qISNrt9D9JdsFGyDYlcSQ= added successfully -INFO: 2025/02/22 23:28:58 Peer NMrcorGgTTi4tAUZ1lLru0qISNrt9D9JdsFGyDYlcSQ= removed successfully -INFO: 2025/02/22 23:28:58 Peer n8ZKTG8vsROL/OiqHYJELU/Rg9XDifz0YjE/lQsL0m0= added successfully -INFO: 2025/02/22 23:33:59 Peer n8ZKTG8vsROL/OiqHYJELU/Rg9XDifz0YjE/lQsL0m0= removed successfully -INFO: 2025/02/22 23:33:59 Peer /i8YTgrLkZh08HKXLXqNFQJsyg1E8I2ELXqF0zuP9D8= added successfully -INFO: 2025/02/22 23:34:06 Peer /i8YTgrLkZh08HKXLXqNFQJsyg1E8I2ELXqF0zuP9D8= removed successfully -INFO: 2025/02/22 23:34:06 Peer 50+RB00sDoSG+KAKzl/baaqPkKGOe7upX7uqRCKqsRo= added successfully -INFO: 2025/02/22 23:35:07 Peer 50+RB00sDoSG+KAKzl/baaqPkKGOe7upX7uqRCKqsRo= removed successfully -INFO: 2025/02/22 23:35:07 Peer Aa2Y2NEmc+SITlT89+fsOeqDkXJVu9RBY14+77TXa3w= added successfully -INFO: 2025/02/23 00:55:55 Peer Aa2Y2NEmc+SITlT89+fsOeqDkXJVu9RBY14+77TXa3w= removed successfully -INFO: 2025/02/23 00:55:55 Peer 2AXNjMQzT7GGvdbIG6MJVFpO3FIzQ+qCqZkdSnBA3DE= added successfully diff --git a/wg/wg.go b/wg/wg.go index fa6760f..f77bdcd 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -51,12 +51,12 @@ type WireGuardService struct { wgClient *wgctrl.Client config WgConfig key wgtypes.Key - reachableAt string newtId string lastReadings map[string]PeerReading mu sync.Mutex - port uint16 + Port uint16 stopHolepunch chan struct{} + host string } // Add this type definition @@ -113,7 +113,7 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort) } -func NewWireGuardService(interfaceName string, mtu int, reachableAt string, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) { +func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client) (*WireGuardService, error) { wgClient, err := wgctrl.New() if err != nil { return nil, fmt.Errorf("failed to create WireGuard client: %v", err) @@ -155,20 +155,13 @@ func NewWireGuardService(interfaceName string, mtu int, reachableAt string, gene client: wsClient, wgClient: wgClient, key: key, - reachableAt: reachableAt, newtId: newtId, lastReadings: make(map[string]PeerReading), - port: port, + Port: port, stopHolepunch: make(chan struct{}), + host: host, } - if err := service.sendUDPHolePunch(host + ":21820"); err != nil { - logger.Error("Failed to send UDP hole punch: %v", err) - } - - // start the UDP holepunch - go service.keepSendingUDPHolePunch(host) - // Register websocket handlers wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig) wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer) @@ -185,7 +178,6 @@ func (s *WireGuardService) LoadRemoteConfig() error { err := s.client.SendMessage("newt/wg/get-config", map[string]interface{}{ "publicKey": fmt.Sprintf("%s", s.key.PublicKey().String()), - "endpoint": s.reachableAt, }) if err != nil { logger.Error("Failed to send registration message: %v", err) @@ -216,9 +208,6 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { } s.config = config - // stop the holepunch - // close(s.stopHolepunch) - // Ensure the WireGuard interface and peers are configured if err := s.ensureWireguardInterface(config); err != nil { logger.Error("Failed to ensure WireGuard interface: %v", err) @@ -227,6 +216,13 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { if err := s.ensureWireguardPeers(config.Peers); err != nil { logger.Error("Failed to ensure WireGuard peers: %v", err) } + + if err := s.sendUDPHolePunch(s.host + ":21820"); err != nil { + logger.Error("Failed to send UDP hole punch: %v", err) + } + + // start the UDP holepunch + go s.keepSendingUDPHolePunch(s.host) } func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { @@ -245,6 +241,17 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { } } else { logger.Info("WireGuard interface %s already exists\n", s.interfaceName) + + // get the exising wireguard port + device, err := s.wgClient.Device(s.interfaceName) + if err != nil { + return fmt.Errorf("failed to get device: %v", err) + } + + // get the existing port + s.Port = uint16(device.ListenPort) + logger.Info("WireGuard interface %s already exists with port %d\n", s.interfaceName, s.Port) + return nil } @@ -273,7 +280,7 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { } // Use the service's fixed port instead of the config port - *config.ListenPort = int(s.port) + *config.ListenPort = int(s.Port) // Create and configure the WireGuard interface err = s.wgClient.ConfigureDevice(s.interfaceName, config) @@ -390,6 +397,7 @@ func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) { err = s.addPeer(peer) if err != nil { + logger.Info("Error adding peer: %v", err) return } } @@ -411,16 +419,18 @@ func (s *WireGuardService) addPeer(peer Peer) error { } // add keep alive using *time.Duration of 1 second keepalive := time.Second - endpoint, err := net.ResolveUDPAddr("udp", peer.Endpoint) - if err != nil { - return fmt.Errorf("failed to resolve endpoint address: %w", err) - } + // endpoint, err := net.ResolveUDPAddr("udp", peer.Endpoint) + // if err != nil { + // return fmt.Errorf("failed to resolve endpoint address: %w", err) + // } + + // make the endpoint localhost to test peerConfig := wgtypes.PeerConfig{ PublicKey: pubKey, AllowedIPs: allowedIPs, PersistentKeepaliveInterval: &keepalive, - Endpoint: endpoint, + // Endpoint: endpoint, } config := wgtypes.Config{ @@ -626,7 +636,7 @@ func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { client := &network.PeerNet{ IP: clientIP, - Port: s.port, + Port: s.Port, NewtID: s.newtId, } @@ -647,27 +657,11 @@ func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { return fmt.Errorf("failed to send UDP packet: %v", err) } - // logger.Info("Sent UDP hole punch to %s", serverAddr) - - // // Wait for response if needed - // response, err := network.RecvDataPacket(rawConn, server, client) - // if err != nil { - // if err, ok := err.(net.Error); ok && err.Timeout() { - // return fmt.Errorf("connection to %s timed out", serverAddr) - // } - // return fmt.Errorf("error receiving response: %v", err) - // } - - // // Process response if needed - // if len(response) > 0 { - // logger.Info("Received response from server") - // } - return nil } func (s *WireGuardService) keepSendingUDPHolePunch(host string) { - ticker := time.NewTicker(1 * time.Second) + ticker := time.NewTicker(3 * time.Second) defer ticker.Stop() for { From cd3ec0b259e8635015303b88ddadb745ab481521 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 23 Feb 2025 20:18:25 -0500 Subject: [PATCH 17/19] Support relay switch --- wg/wg.go | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/wg/wg.go b/wg/wg.go index f77bdcd..bcb7cda 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -419,18 +419,29 @@ func (s *WireGuardService) addPeer(peer Peer) error { } // add keep alive using *time.Duration of 1 second keepalive := time.Second - // endpoint, err := net.ResolveUDPAddr("udp", peer.Endpoint) - // if err != nil { - // return fmt.Errorf("failed to resolve endpoint address: %w", err) - // } - // make the endpoint localhost to test + var peerConfig wgtypes.PeerConfig + if peer.Endpoint != "" { + endpoint, err := net.ResolveUDPAddr("udp", peer.Endpoint) + if err != nil { + return fmt.Errorf("failed to resolve endpoint address: %w", err) + } - peerConfig := wgtypes.PeerConfig{ - PublicKey: pubKey, - AllowedIPs: allowedIPs, - PersistentKeepaliveInterval: &keepalive, - // Endpoint: endpoint, + // make the endpoint localhost to test + + peerConfig = wgtypes.PeerConfig{ + PublicKey: pubKey, + AllowedIPs: allowedIPs, + PersistentKeepaliveInterval: &keepalive, + Endpoint: endpoint, + } + } else { + peerConfig = wgtypes.PeerConfig{ + PublicKey: pubKey, + AllowedIPs: allowedIPs, + PersistentKeepaliveInterval: &keepalive, + } + logger.Info("Added peer with no endpoint!") } config := wgtypes.Config{ From 5e673c829b1d88a7350f851445c1fd593980b4d1 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 24 Feb 2025 10:05:35 -0500 Subject: [PATCH 18/19] Clean up when wg is used --- main.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index da80c48..7618d3b 100644 --- a/main.go +++ b/main.go @@ -12,6 +12,7 @@ import ( "net/netip" "os" "os/signal" + "runtime" "strconv" "strings" "syscall" @@ -340,6 +341,12 @@ func main() { var wgData WgData if generateAndSaveKeyTo != "" { + // make sure we are running on linux + if runtime.GOOS != "linux" { + logger.Fatal("Tunnel management is only supported on Linux right now!") + os.Exit(1) + } + var host = endpoint if strings.HasPrefix(host, "http://") { host = strings.TrimPrefix(host, "http://") @@ -569,7 +576,9 @@ persistent_keepalive_interval=5`, fixKey(fmt.Sprintf("%s", privateKey)), fixKey( return err } - wgService.LoadRemoteConfig() + if wgService != nil { + wgService.LoadRemoteConfig() + } logger.Info("Sent registration message") return nil From 067e07929353c70eb4c43a3c596fe52a7ddf7ccd Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 12 Mar 2025 20:37:57 -0400 Subject: [PATCH 19/19] Handle / better --- main.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/main.go b/main.go index 7618d3b..f585222 100644 --- a/main.go +++ b/main.go @@ -214,6 +214,9 @@ func resolveDomain(domain string) (string, error) { host = strings.TrimPrefix(host, "https://") } + // if there are any trailing slashes, remove them + host = strings.TrimSuffix(host, "/") + // Lookup IP addresses ips, err := net.LookupIP(host) if err != nil { @@ -354,6 +357,8 @@ func main() { host = strings.TrimPrefix(host, "https://") } + host = strings.TrimSuffix(host, "/") + // Create WireGuard service wgService, err = wg.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client) if err != nil {