diff --git a/main.go b/main.go index ad41fd1..59fb7c6 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "fmt" "math/rand" "net" + "net/http" "net/netip" "os" "os/exec" @@ -52,6 +53,13 @@ type TargetData struct { Targets []string `json:"targets"` } +// ExitNode represents an exit node with an ID, endpoint, and weight. +type ExitNode struct { + ID string `json:"id"` + Endpoint string `json:"endpoint"` + Weight float64 `json:"weight"` +} + func fixKey(key string) string { // Remove any whitespace key = strings.TrimSpace(key) @@ -363,6 +371,7 @@ var ( updownScript string tlsPrivateKey string dockerSocket string + publicKey wgtypes.Key ) func main() { @@ -623,6 +632,88 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub } }) + client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) { + logger.Info("Received ping message") + + // Parse the incoming list of exit nodes + var exitNodes []ExitNode + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + if err := json.Unmarshal(jsonData, &exitNodes); err != nil { + logger.Info("Error unmarshaling exit node data: %v", err) + return + } + if len(exitNodes) == 0 { + logger.Info("No exit nodes provided") + return + } + + type nodeResult struct { + Node ExitNode + Latency time.Duration + Err error + } + + results := make([]nodeResult, len(exitNodes)) + for i, node := range exitNodes { + start := time.Now() + client := &http.Client{ + Timeout: 5 * time.Second, + } + url := node.Endpoint + if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { + url = "http://" + url + } + if !strings.HasSuffix(url, "/ping") { + url = strings.TrimRight(url, "/") + "/ping" + } + resp, err := client.Get(url) + latency := time.Since(start) + if err != nil { + logger.Warn("Failed to ping exit node %s (%s): %v", node.ID, url, err) + results[i] = nodeResult{Node: node, Latency: latency, Err: err} + continue + } + resp.Body.Close() + results[i] = nodeResult{Node: node, Latency: latency, Err: nil} + logger.Info("Exit node %s latency: %v", node.ID, latency) + } + + // Select the best node based on weighted score (latency * (1/weight)) + var bestNode *ExitNode + var bestScore float64 = 1e12 // large initial value + for _, res := range results { + if res.Err != nil || res.Node.Weight <= 0 { + continue + } + score := float64(res.Latency.Milliseconds()) / res.Node.Weight + logger.Info("Exit node %s score: %.2f (latency: %dms, weight: %.2f)", res.Node.ID, score, res.Latency.Milliseconds(), res.Node.Weight) + if score < bestScore { + bestScore = score + bestNode = &res.Node + } + } + + if bestNode == nil { + logger.Warn("No suitable exit node found") + return + } + + logger.Info("Selected exit node: %s (%s)", bestNode.ID, bestNode.Endpoint) + + err = client.SendMessage("newt/wg/register", map[string]interface{}{ + "publicKey": publicKey.String(), + "exitNode": bestNode.ID, + }) + if err != nil { + logger.Error("Failed to send registration message: %v", err) + return + } + }) + client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) { logger.Info("Received: %+v", msg) @@ -762,15 +853,13 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub }) client.OnConnect(func() error { - publicKey := privateKey.PublicKey() + publicKey = privateKey.PublicKey() logger.Debug("Public key: %s", publicKey) - err := client.SendMessage("newt/wg/register", map[string]interface{}{ - "publicKey": publicKey.String(), - }) + // request from the server the list of nodes to ping at newt/ping/request + err := client.SendMessage("newt/ping/request", map[string]interface{}{}) if err != nil { - logger.Error("Failed to send registration message: %v", err) - return err + logger.Error("Failed to send ping request: %v", err) } if wgService != nil {