diff --git a/clients.go b/clients.go new file mode 100644 index 0000000..78f7844 --- /dev/null +++ b/clients.go @@ -0,0 +1,125 @@ +package main + +import ( + "fmt" + "strings" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/proxy" + "github.com/fosrl/newt/websocket" + "golang.zx2c4.com/wireguard/tun/netstack" + + "github.com/fosrl/newt/wgnetstack" + "github.com/fosrl/newt/wgtester" +) + +var wgService *wgnetstack.WireGuardService +var wgTesterServer *wgtester.Server +var ready bool + +func setupClients(client *websocket.Client) { + var host = endpoint + if strings.HasPrefix(host, "http://") { + host = strings.TrimPrefix(host, "http://") + } else if strings.HasPrefix(host, "https://") { + host = strings.TrimPrefix(host, "https://") + } + + host = strings.TrimSuffix(host, "/") + + if useNativeInterface { + setupClientsNative(client, host) + } else { + setupClientsNetstack(client, host) + } + + ready = true +} + +func setupClientsNetstack(client *websocket.Client, host string) { + logger.Info("Setting up clients with netstack...") + // Create WireGuard service + wgService, err = wgnetstack.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client, "8.8.8.8") + if err != nil { + logger.Fatal("Failed to create WireGuard service: %v", err) + } + + // // Set up callback to restart wgtester with netstack when WireGuard is ready + wgService.SetOnNetstackReady(func(tnet *netstack.Net) { + + wgTesterServer = wgtester.NewServerWithNetstack("0.0.0.0", wgService.Port, id, tnet) // TODO: maybe make this the same ip of the wg server? + err := wgTesterServer.Start() + if err != nil { + logger.Error("Failed to start WireGuard tester server: %v", err) + } + }) + + client.OnTokenUpdate(func(token string) { + wgService.SetToken(token) + }) +} + +func setDownstreamTNetstack(tnet *netstack.Net) { + if wgService != nil { + wgService.SetOthertnet(tnet) + } +} + +func closeClients() { + logger.Info("Closing clients...") + if wgService != nil { + wgService.Close(!keepInterface) + wgService = nil + } + + closeWgServiceNative() + + if wgTesterServer != nil { + wgTesterServer.Stop() + wgTesterServer = nil + } +} + +func clientsHandleNewtConnection(publicKey string, endpoint string) { + if !ready { + return + } + + // split off the port from the endpoint + parts := strings.Split(endpoint, ":") + if len(parts) < 2 { + logger.Error("Invalid endpoint format: %s", endpoint) + return + } + endpoint = strings.Join(parts[:len(parts)-1], ":") + + if wgService != nil { + wgService.StartHolepunch(publicKey, endpoint) + } + + clientsHandleNewtConnectionNative(publicKey, endpoint) +} + +func clientsOnConnect() { + if !ready { + return + } + if wgService != nil { + wgService.LoadRemoteConfig() + } + + clientsOnConnectNative() +} + +func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) { + if !ready { + return + } + // add a udp proxy for localost and the wgService port + // TODO: make sure this port is not used in a target + if wgService != nil { + pm.AddTarget("udp", tunnelIp, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port)) + } + + clientsAddProxyTargetNative(pm, tunnelIp) +} diff --git a/linux.go b/linux.go index a88765a..f8996e9 100644 --- a/linux.go +++ b/linux.go @@ -4,107 +4,74 @@ package main import ( "fmt" - "strings" + "os" + "runtime" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/proxy" "github.com/fosrl/newt/websocket" - "golang.zx2c4.com/wireguard/tun/netstack" - - "github.com/fosrl/newt/wgnetstack" + "github.com/fosrl/newt/wg" "github.com/fosrl/newt/wgtester" ) -var wgService *wgnetstack.WireGuardService -var wgTesterServer *wgtester.Server +var wgServiceNative *wg.WireGuardService -func setupClients(client *websocket.Client) { - var host = endpoint - if strings.HasPrefix(host, "http://") { - host = strings.TrimPrefix(host, "http://") - } else if strings.HasPrefix(host, "https://") { - host = strings.TrimPrefix(host, "https://") +func setupClientsNative(client *websocket.Client, host string) { + + if runtime.GOOS != "linux" { + logger.Fatal("Tunnel management is only supported on Linux right now!") + os.Exit(1) } - host = strings.TrimSuffix(host, "/") + // make sure we are sudo + if os.Geteuid() != 0 { + logger.Fatal("You must run this program as root to manage tunnels on Linux.") + os.Exit(1) + } - if useNativeInterface { + // Create WireGuard service + wgServiceNative, err = wg.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client) + if err != nil { + logger.Fatal("Failed to create WireGuard service: %v", err) + } - } else { - // Create WireGuard service - wgService, err = wgnetstack.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client, "8.8.8.8") - if err != nil { - logger.Fatal("Failed to create WireGuard service: %v", err) - } - - // // Set up callback to restart wgtester with netstack when WireGuard is ready - wgService.SetOnNetstackReady(func(tnet *netstack.Net) { - - wgTesterServer = wgtester.NewServerWithNetstack("0.0.0.0", wgService.Port, id, tnet) // TODO: maybe make this the same ip of the wg server? - err := wgTesterServer.Start() - if err != nil { - logger.Error("Failed to start WireGuard tester server: %v", err) - } - // logger.Info("WireGuard netstack is ready, restarting wgtester with netstack") - // if err := wgTesterServer.RestartWithNetstack(tnet); err != nil { - // logger.Error("Failed to restart wgtester with netstack: %v", err) - // } else { - // logger.Info("WGTester successfully restarted with netstack") - // } - }) + wgTesterServer = wgtester.NewServer("0.0.0.0", wgServiceNative.Port, id) // TODO: maybe make this the same ip of the wg server? + err := wgTesterServer.Start() + if err != nil { + logger.Error("Failed to start WireGuard tester server: %v", err) } client.OnTokenUpdate(func(token string) { - wgService.SetToken(token) + wgServiceNative.SetToken(token) }) } -func setDownstreamTNetstack(tnet *netstack.Net) { - if wgService != nil { - wgService.SetOthertnet(tnet) +func closeWgServiceNative() { + if wgServiceNative != nil { + wgServiceNative.Close(!keepInterface) + wgServiceNative = nil } } -func closeClients() { - if wgService != nil { - wgService.Close(!keepInterface) - wgService = nil - } - - if wgTesterServer != nil { - wgTesterServer.Stop() - wgTesterServer = nil +func clientsOnConnectNative() { + if wgServiceNative != nil { + wgServiceNative.LoadRemoteConfig() } } -func clientsHandleNewtConnection(publicKey string, endpoint string) { - if wgService == nil { - return +func clientsHandleNewtConnectionNative(publicKey, endpoint string) { + if wgServiceNative != nil { + wgServiceNative.StartHolepunch(publicKey, endpoint) } - - // split off the port from the endpoint - parts := strings.Split(endpoint, ":") - if len(parts) < 2 { - logger.Error("Invalid endpoint format: %s", endpoint) - return - } - endpoint = strings.Join(parts[:len(parts)-1], ":") - - wgService.StartHolepunch(publicKey, endpoint) } -func clientsOnConnect() { - if wgService == nil { - return - } - wgService.LoadRemoteConfig() -} - -func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) { - if wgService == nil { +func clientsAddProxyTargetNative(pm *proxy.ProxyManager, tunnelIp string) { + if !ready { return } // add a udp proxy for localost and the wgService port // TODO: make sure this port is not used in a target - pm.AddTarget("udp", tunnelIp, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port)) + if wgServiceNative != nil { + pm.AddTarget("udp", tunnelIp, int(wgServiceNative.Port), fmt.Sprintf("127.0.0.1:%d", wgServiceNative.Port)) + } } diff --git a/main.go b/main.go index b0d2b74..883f6d7 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,6 @@ import ( "net/netip" "os" "os/signal" - "runtime" "strconv" "strings" "syscall" @@ -141,7 +140,7 @@ func main() { flag.StringVar(&interfaceName, "interface", "newt", "Name of the WireGuard interface") } if generateAndSaveKeyTo == "" { - flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "/tmp/newtkey", "Path to save generated private key") + flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") } flag.BoolVar(&keepInterface, "keep-interface", false, "Keep the WireGuard interface") flag.BoolVar(&useNativeInterface, "native", false, "Use native WireGuard interface (requires WireGuard kernel module) and linux") @@ -269,12 +268,6 @@ func main() { var wgData WgData if acceptClients { - // make sure we are running on linux - if runtime.GOOS != "linux" { - logger.Fatal("Tunnel management is only supported on Linux right now!") - os.Exit(1) - } - setupClients(client) } diff --git a/stub.go b/stub.go index e2360ff..ec91299 100644 --- a/stub.go +++ b/stub.go @@ -7,26 +7,26 @@ import ( "github.com/fosrl/newt/websocket" ) -func setupClients(client *websocket.Client) { +func setupClientsNative(client *websocket.Client, host string) { return // This function is not implemented for non-Linux systems. } -func closeClients() { - // This function is not implemented for non-Linux systems. +func closeWgServiceNative() { + // No-op for non-Linux systems return } -func clientsHandleNewtConnection(publicKey string) { - // This function is not implemented for non-Linux systems. +func clientsOnConnectNative() { + // No-op for non-Linux systems return } -func clientsOnConnect() { - // This function is not implemented for non-Linux systems. +func clientsHandleNewtConnectionNative(publicKey, endpoint string) { + // No-op for non-Linux systems return } -func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) { - // This function is not implemented for non-Linux systems. +func clientsAddProxyTargetNative(pm *proxy.ProxyManager, tunnelIp string) { + // No-op for non-Linux systems return } diff --git a/wg/wg.go b/wg/wg.go index 364b3c8..13aab27 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -152,25 +152,27 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str var key wgtypes.Key // if generateAndSaveKeyTo is provided, generate a private key and save it to the file. if the file already exists, load the key from the file - if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { - // generate a new private key - key, err = wgtypes.GeneratePrivateKey() - if err != nil { - logger.Fatal("Failed to generate private key: %v", err) - } - // save the key to the file - err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644) - if err != nil { - logger.Fatal("Failed to save private key: %v", err) - } - } else { - keyData, err := os.ReadFile(generateAndSaveKeyTo) - if err != nil { - logger.Fatal("Failed to read private key: %v", err) - } - key, err = wgtypes.ParseKey(string(keyData)) - if err != nil { - logger.Fatal("Failed to parse private key: %v", err) + key, err = wgtypes.GeneratePrivateKey() + if err != nil { + return nil, fmt.Errorf("failed to generate private key: %v", err) + } + + // Load or generate private key + if generateAndSaveKeyTo != "" { + if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { + keyData, err := os.ReadFile(generateAndSaveKeyTo) + if err != nil { + return nil, fmt.Errorf("failed to read private key: %v", err) + } + key, err = wgtypes.ParseKey(strings.TrimSpace(string(keyData))) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %v", err) + } + } else { + err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644) + if err != nil { + return nil, fmt.Errorf("failed to save private key: %v", err) + } } } diff --git a/wgnetstack/wgnetstack.go b/wgnetstack/wgnetstack.go index 73b947d..3401e42 100644 --- a/wgnetstack/wgnetstack.go +++ b/wgnetstack/wgnetstack.go @@ -166,21 +166,7 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { continue // Port is in use or there was an error, try next port } - // Check if port+1 is also available - addr2 := &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: int(port + 1), - } - conn2, err2 := net.ListenUDP("udp", addr2) - if err2 != nil { - // The next port is not available, so close the first connection and try again - conn1.Close() - continue - } - - // Both ports are available, close connections and return the first port conn1.Close() - conn2.Close() return port, nil } @@ -189,27 +175,29 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string) (*WireGuardService, error) { var key wgtypes.Key + var err error + + key, err = wgtypes.GeneratePrivateKey() + if err != nil { + return nil, fmt.Errorf("failed to generate private key: %v", err) + } // Load or generate private key - if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { - // Generate a new private key - key, err = wgtypes.GeneratePrivateKey() - if err != nil { - return nil, fmt.Errorf("failed to generate private key: %v", err) - } - // Save the key to the file - err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644) - if err != nil { - return nil, fmt.Errorf("failed to save private key: %v", err) - } - } else { - keyData, err := os.ReadFile(generateAndSaveKeyTo) - if err != nil { - return nil, fmt.Errorf("failed to read private key: %v", err) - } - key, err = wgtypes.ParseKey(strings.TrimSpace(string(keyData))) - if err != nil { - return nil, fmt.Errorf("failed to parse private key: %v", err) + if generateAndSaveKeyTo != "" { + if _, err := os.Stat(generateAndSaveKeyTo); os.IsNotExist(err) { + keyData, err := os.ReadFile(generateAndSaveKeyTo) + if err != nil { + return nil, fmt.Errorf("failed to read private key: %v", err) + } + key, err = wgtypes.ParseKey(strings.TrimSpace(string(keyData))) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %v", err) + } + } else { + err = os.WriteFile(generateAndSaveKeyTo, []byte(key.String()), 0644) + if err != nil { + return nil, fmt.Errorf("failed to save private key: %v", err) + } } } @@ -632,6 +620,11 @@ func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) { return } + if s.device == nil { + logger.Info("WireGuard device is not initialized") + return + } + err = s.addPeerToDevice(peer) if err != nil { logger.Info("Error adding peer: %v", err) @@ -658,6 +651,11 @@ func (s *WireGuardService) handleRemovePeer(msg websocket.WSMessage) { return } + if s.device == nil { + logger.Info("WireGuard device is not initialized") + return + } + if err := s.removePeer(request.PublicKey); err != nil { logger.Info("Error removing peer: %v", err) return @@ -711,6 +709,11 @@ func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) { return } + if s.device == nil { + logger.Info("WireGuard device is not initialized") + return + } + // Build IPC configuration string to update the peer config := fmt.Sprintf("public_key=%s\nupdate_only=true", fixKey(pubKey.String())) @@ -935,12 +938,9 @@ func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { return fmt.Errorf("failed to resolve server hostname") } - // Get client IP based on route to server for local binding - clientIP := network.GetClientIP(serverIPAddr.IP) - // Create local UDP address using the same port as WireGuard localAddr := &net.UDPAddr{ - IP: clientIP, + IP: net.IPv4zero, Port: int(s.Port), }