diff --git a/main.go b/main.go index a522df1..3ec52ca 100644 --- a/main.go +++ b/main.go @@ -4,12 +4,14 @@ import ( "bytes" "encoding/base64" "encoding/hex" + "encoding/json" "flag" "fmt" "log" "math/rand" "net/netip" "newt/proxy" + "newt/websocket" "os" "os/signal" "strings" @@ -20,9 +22,18 @@ import ( "golang.org/x/net/ipv4" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun/netstack" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +type WgData struct { + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` + TunnelIP string `json:"tunnelIP"` +} + func fixKey(key string) string { // Remove any whitespace key = strings.TrimSpace(key) @@ -73,117 +84,232 @@ func ping(tnet *netstack.Net, dst string) { func main() { var ( - tunnelIP string - privateKey string - publicKey string - endpoint string - tcpTargets string - udpTargets string - listenIP string - serverIP string dns string + id string + secret string + privateKey wgtypes.Key + err error ) - flag.StringVar(&tunnelIP, "tunnel-ip", "", "Tunnel IP address") - flag.StringVar(&privateKey, "private-key", "", "WireGuard private key") - flag.StringVar(&publicKey, "public-key", "", "WireGuard public key") - flag.StringVar(&endpoint, "endpoint", "", "WireGuard endpoint (host:port)") - flag.StringVar(&tcpTargets, "tcp-targets", "", "Comma-separated list of TCP targets (host:port)") - flag.StringVar(&udpTargets, "udp-targets", "", "Comma-separated list of UDP targets (host:port)") - flag.StringVar(&listenIP, "listen-ip", "", "IP to listen for incoming connections") - flag.StringVar(&serverIP, "server-ip", "", "IP to filter and ping on the server side. Inside tunnel...") flag.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use") + flag.StringVar(&id, "id", "", "Newt ID") + flag.StringVar(&secret, "secret", "", "Newt secret") flag.Parse() - // Create TUN device and network stack - tun, tnet, err := netstack.CreateNetTUN( - []netip.Addr{netip.MustParseAddr(tunnelIP)}, - []netip.Addr{netip.MustParseAddr(dns)}, - 1420) + privateKey, err = wgtypes.GeneratePrivateKey() if err != nil { - log.Panic(err) + log.Fatalf("Failed to generate private key: %v", err) } - // Create WireGuard device - dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) + // Create a new client + client, err := websocket.NewClient( + // the id and secret from the params + id, + secret, + websocket.WithBaseURL("http://localhost:3000/api/v1"), + ) + if err != nil { + log.Fatal(err) + } - // Configure WireGuard - config := fmt.Sprintf(`private_key=%s + // Create TUN device and network stack + var tun tun.Device + var tnet *netstack.Net + var dev *device.Device + var pm *proxy.ProxyManager + var connected bool + var wgData WgData + + // Register handlers for different message types + client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) { + if connected { + log.Printf("Already connected! Put I will send a ping anyway...") + ping(tnet, wgData.ServerIP) + return + } + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + log.Printf("Error marshaling data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &wgData); err != nil { + log.Printf("Error unmarshaling target data: %v", err) + return + } + + log.Printf("Received: %+v", msg) + tun, tnet, err = netstack.CreateNetTUN( + []netip.Addr{netip.MustParseAddr(wgData.TunnelIP)}, + []netip.Addr{netip.MustParseAddr(dns)}, + 1420) + if err != nil { + log.Panic(err) + } + + // Create WireGuard device + dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) + + // Configure WireGuard + config := fmt.Sprintf(`private_key=%s public_key=%s allowed_ip=%s/32 endpoint=%s -persistent_keepalive_interval=5 -`, fixKey(privateKey), fixKey(publicKey), serverIP, endpoint) +persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.PublicKey), wgData.ServerIP, wgData.Endpoint) - err = dev.IpcSet(config) - if err != nil { - log.Panic(err) - } - - // Bring up the device - err = dev.Up() - if err != nil { - log.Panic(err) - } - - // Ping to bring the tunnel up on the server side quickly - ping(tnet, serverIP) - - // Create proxy manager - pm := proxy.NewProxyManager(tnet) - - // Add TCP targets - if tcpTargets != "" { - targets := strings.Split(tcpTargets, ",") - for _, t := range targets { - // Split the first number off of the target with : separator and use as the port - parts := strings.Split(t, ":") - if len(parts) != 2 { - log.Panicf("Invalid target: %s", t) - } - // get the port as a int - port := 0 - _, err := fmt.Sscanf(parts[0], "%d", &port) - if err != nil { - log.Panicf("Invalid port: %s", parts[0]) - } - target := parts[1] - pm.AddTarget("tcp", listenIP, port, target) + err = dev.IpcSet(config) + if err != nil { + log.Panic(err) } - } - // Add UDP targets - if udpTargets != "" { - targets := strings.Split(udpTargets, ",") - for _, t := range targets { - // Split the first number off of the target with : separator and use as the port - parts := strings.Split(t, ":") - if len(parts) != 2 { - log.Panicf("Invalid target: %s", t) - } - // get the port as a int - port := 0 - _, err := fmt.Sscanf(parts[0], "%d", &port) - if err != nil { - log.Panicf("Invalid port: %s", parts[0]) - } - target := parts[1] - pm.AddTarget("udp", listenIP, port, target) + // Bring up the device + err = dev.Up() + if err != nil { + log.Panic(err) } - } - // Start proxies - err = pm.Start() + // Ping to bring the tunnel up on the server side quickly + ping(tnet, wgData.ServerIP) + + // Create proxy manager + pm = proxy.NewProxyManager(tnet) + + connected = true + }) + + client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) { + log.Printf("Received: %+v", msg) + + // if there is no wgData or pm, we can't add targets + if wgData.TunnelIP == "" || pm == nil { + log.Printf("No tunnel IP or proxy manager available") + return + } + + type TargetData struct { + Targets []string `json:"targets"` + } + // Define a struct for the expected data structure + jsonData, err := json.Marshal(msg.Data) + if err != nil { + log.Printf("Error marshaling data: %v", err) + return + } + + // Parse into our target structure + var targetData TargetData + if err := json.Unmarshal(jsonData, &targetData); err != nil { + log.Printf("Error unmarshaling target data: %v", err) + return + } + + if len(targetData.Targets) > 0 { + + // Stop the proxy manager before adding new targets + err = pm.Stop() + if err != nil { + log.Panic(err) + } + + for _, t := range targetData.Targets { + // Split the first number off of the target with : separator and use as the port + parts := strings.Split(t, ":") + if len(parts) != 2 { + log.Printf("Invalid target format: %s", t) + continue + } + + // Get the port as an int + port := 0 + _, err := fmt.Sscanf(parts[0], "%d", &port) + if err != nil { + log.Printf("Invalid port: %s", parts[0]) + continue + } + + target := parts[1] + pm.AddTarget("tcp", wgData.TunnelIP, port, target) + } + + err = pm.Start() + if err != nil { + log.Panic(err) + } + } + }) + + client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) { + log.Printf("Received: %+v", msg) + + // if there is no wgData or pm, we can't add targets + if wgData.TunnelIP == "" || pm == nil { + log.Printf("No tunnel IP or proxy manager available") + return + } + + type TargetData struct { + Targets []string `json:"targets"` + } + jsonData, err := json.Marshal(msg.Data) + if err != nil { + log.Printf("Error marshaling data: %v", err) + return + } + + var targetData TargetData + if err := json.Unmarshal(jsonData, &targetData); err != nil { + log.Printf("Error unmarshaling target data: %v", err) + return + } + + if len(targetData.Targets) > 0 { + err = pm.Stop() + if err != nil { + log.Panic(err) + } + + for _, t := range targetData.Targets { + // Split the first number off of the target with : separator and use as the port + parts := strings.Split(t, ":") + if len(parts) != 2 { + log.Printf("Invalid target format: %s", t) + continue + } + + // Get the port as an int + port := 0 + _, err := fmt.Sscanf(parts[0], "%d", &port) + if err != nil { + log.Printf("Invalid port: %s", parts[0]) + continue + } + + target := parts[1] + pm.AddTarget("udp", wgData.TunnelIP, port, target) + } + + err = pm.Start() + if err != nil { + log.Panic(err) + } + } + }) + + // Connect to the WebSocket server + if err := client.Connect(); err != nil { + log.Fatal(err) + } + defer client.Close() + + // TODO: we need to send the public key to the server to trigger it to respond to create the tunnel + // TODO: how to retry? + err = client.SendMessage("newt/wg/register", map[string]interface{}{ + "content": "Hello, World!", + }) if err != nil { - log.Panic(err) - } - - url := "ws://localhost/api/v1/ws" - token := "your-auth-token" - - if err := websocket.connectWebSocket(url, token); err != nil { - log.Fatalf("WebSocket error: %v", err) + log.Printf("Failed to send message: %v", err) } // Wait for interrupt signal diff --git a/proxy/manager.go b/proxy/manager.go index 65298e6..45d667a 100644 --- a/proxy/manager.go +++ b/proxy/manager.go @@ -78,6 +78,25 @@ func (pm *ProxyManager) Start() error { return nil } +func (pm *ProxyManager) Stop() error { + pm.Lock() + defer pm.Unlock() + + for i := range pm.targets { + target := &pm.targets[i] + close(target.cancel) + target.Lock() + if target.listener != nil { + target.listener.Close() + } + if target.udpConn != nil { + target.udpConn.Close() + } + target.Unlock() + } + return nil +} + func (pm *ProxyManager) serveTCP(target *ProxyTarget) { listener, err := pm.tnet.ListenTCP(&net.TCPAddr{ IP: net.ParseIP(target.Listen), diff --git a/websocket/client.go b/websocket/client.go new file mode 100644 index 0000000..34c0665 --- /dev/null +++ b/websocket/client.go @@ -0,0 +1,202 @@ +package websocket + +import ( + "bytes" + "encoding/json" + "fmt" + "log" + "net/http" + "net/url" + "sync" + + "github.com/gorilla/websocket" +) + +type Client struct { + conn *websocket.Conn + config *Config + baseURL string + handlers map[string]MessageHandler + done chan struct{} + handlersMux sync.RWMutex +} + +type ClientOption func(*Client) + +type MessageHandler func(message WSMessage) + +// WithBaseURL sets the base URL for the client +func WithBaseURL(url string) ClientOption { + return func(c *Client) { + c.baseURL = url + } +} + +// NewClient creates a new Newt client +func NewClient(newtID, secret string, opts ...ClientOption) (*Client, error) { + config := &Config{ + NewtID: newtID, + Secret: secret, + } + + client := &Client{ + config: config, + baseURL: "http://localhost:3000", // default value + handlers: make(map[string]MessageHandler), + done: make(chan struct{}), + } + + // Apply options + for _, opt := range opts { + opt(client) + } + + // Load existing config if available + if err := client.loadConfig(); err != nil { + return nil, fmt.Errorf("failed to load config: %w", err) + } + + return client, nil +} + +// Connect establishes the WebSocket connection +func (c *Client) Connect() error { + // Get token for authentication + token, err := c.getToken() + if err != nil { + return fmt.Errorf("failed to get token: %w", err) + } + + // Update config with new token and save + c.config.Token = token + if err := c.saveConfig(); err != nil { + return fmt.Errorf("failed to save config: %w", err) + } + + // Connect to WebSocket + wsURL := fmt.Sprintf("ws://%s/ws", c.baseURL[7:]) // Remove http:// prefix + u, err := url.Parse(wsURL) + if err != nil { + return fmt.Errorf("failed to parse WebSocket URL: %w", err) + } + + q := u.Query() + q.Set("token", token) + u.RawQuery = q.Encode() + + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + return fmt.Errorf("failed to connect to WebSocket: %w", err) + } + + c.conn = conn + go c.readPump() + + return nil +} + +// Close closes the WebSocket connection +func (c *Client) Close() error { + close(c.done) + if c.conn != nil { + return c.conn.Close() + } + return nil +} + +// SendMessage sends a message through the WebSocket connection +func (c *Client) SendMessage(messageType string, data interface{}) error { + if c.conn == nil { + return fmt.Errorf("not connected") + } + + msg := WSMessage{ + Type: messageType, + Data: data, + } + + return c.conn.WriteJSON(msg) +} + +// RegisterHandler registers a handler for a specific message type +func (c *Client) RegisterHandler(messageType string, handler MessageHandler) { + c.handlersMux.Lock() + defer c.handlersMux.Unlock() + c.handlers[messageType] = handler +} + +// readPump pumps messages from the WebSocket connection +func (c *Client) readPump() { + defer c.conn.Close() + + for { + select { + case <-c.done: + return + default: + var msg WSMessage + err := c.conn.ReadJSON(&msg) + if err != nil { + log.Printf("read error: %v", err) + return + } + + c.handlersMux.RLock() + if handler, ok := c.handlers[msg.Type]; ok { + handler(msg) + } + c.handlersMux.RUnlock() + } + } +} + +func (c *Client) getToken() (string, error) { + // If we already have a token, try to use it + if c.config.Token != "" { + tokenCheckData := map[string]interface{}{ + "newtId": c.config.NewtID, + "secret": c.config.Secret, + "token": c.config.Token, + } + jsonData, _ := json.Marshal(tokenCheckData) + + resp, err := http.Post(c.baseURL+"/auth/newt/get-token", "application/json", bytes.NewBuffer(jsonData)) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var tokenResp TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return "", err + } + + if tokenResp.Success && tokenResp.Message == "Token session already valid" { + return c.config.Token, nil + } + } + + // Get a new token + tokenData := map[string]interface{}{ + "newtId": c.config.NewtID, + "secret": c.config.Secret, + } + jsonData, _ := json.Marshal(tokenData) + + resp, err := http.Post(c.baseURL+"/auth/newt/get-token", "application/json", bytes.NewBuffer(jsonData)) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var tokenResp TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return "", err + } + + if !tokenResp.Success { + return "", fmt.Errorf("failed to get token: %s", tokenResp.Message) + } + + return tokenResp.Data.Token, nil +} diff --git a/websocket/config.go b/websocket/config.go new file mode 100644 index 0000000..b47e7be --- /dev/null +++ b/websocket/config.go @@ -0,0 +1,57 @@ +package websocket + +import ( + "encoding/json" + "io/ioutil" + "log" + "os" + "path/filepath" + "runtime" +) + +func getConfigPath() string { + var configDir string + switch runtime.GOOS { + case "darwin": + configDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "newt-client") + case "windows": + configDir = filepath.Join(os.Getenv("APPDATA"), "newt-client") + default: // linux and others + configDir = filepath.Join(os.Getenv("HOME"), ".config", "newt-client") + } + + if err := os.MkdirAll(configDir, 0755); err != nil { + log.Printf("Failed to create config directory: %v", err) + } + + return filepath.Join(configDir, "config.json") +} + +func (c *Client) loadConfig() error { + configPath := getConfigPath() + data, err := ioutil.ReadFile(configPath) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + var config Config + if err := json.Unmarshal(data, &config); err != nil { + return err + } + + // Only update token from saved config + c.config.Token = config.Token + return nil +} + +func (c *Client) saveConfig() error { + configPath := getConfigPath() + data, err := json.MarshalIndent(c.config, "", " ") + if err != nil { + return err + } + return ioutil.WriteFile(configPath, data, 0644) +} diff --git a/websocket/manager.go b/websocket/manager.go deleted file mode 100644 index 9308295..0000000 --- a/websocket/manager.go +++ /dev/null @@ -1,60 +0,0 @@ -package websocket - -import ( - "encoding/json" - "fmt" - "log" - "net/http" - - "github.com/gorilla/websocket" -) - -func connectWebSocket(url, token string) error { - // Create custom header with the auth token - header := http.Header{} - header.Add("Sec-WebSocket-Protocol", token) - - // Create dialer with default options - dialer := websocket.Dialer{ - EnableCompression: true, - } - - // Connect to WebSocket server - conn, resp, err := dialer.Dial(url, header) - if err != nil { - log.Printf("Dial failed: %v", err) - if resp != nil { - log.Printf("HTTP Response Status: %s", resp.Status) - } - return err - } - defer conn.Close() - - log.Printf("Connected to WebSocket server") - - // Message handling loop - for { - // Read message - messageType, message, err := conn.ReadMessage() - if err != nil { - log.Printf("Read error: %v", err) - return err - } - - // Handle text messages (JSON expected) - if messageType == websocket.TextMessage { - // Create a map to store the JSON data - var jsonData map[string]interface{} - - // Unmarshal the JSON message - if err := json.Unmarshal(message, &jsonData); err != nil { - log.Printf("JSON parsing error: %v", err) - // Continue reading messages even if one fails to parse - continue - } - - // Print the parsed JSON message - fmt.Printf("Received message: %+v\n", jsonData) - } - } -} diff --git a/websocket/types.go b/websocket/types.go new file mode 100644 index 0000000..623a0a0 --- /dev/null +++ b/websocket/types.go @@ -0,0 +1,20 @@ +package websocket + +type Config struct { + NewtID string `json:"newtId"` + Secret string `json:"secret"` + Token string `json:"token"` +} + +type TokenResponse struct { + Data struct { + Token string `json:"token"` + } `json:"data"` + Success bool `json:"success"` + Message string `json:"message"` +} + +type WSMessage struct { + Type string `json:"type"` + Data interface{} `json:"data"` +}