diff --git a/clients/clients.go b/clients/clients.go index ccf41aa..b8c7d2d 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -2,6 +2,8 @@ package clients import ( "context" + "crypto/rand" + "encoding/hex" "encoding/json" "fmt" "net" @@ -34,6 +36,7 @@ type WgConfig struct { IpAddress string `json:"ipAddress"` Peers []Peer `json:"peers"` Targets []Target `json:"targets"` + ChainId string `json:"chainId"` } type Target struct { @@ -82,7 +85,8 @@ type WireGuardService struct { host string serverPubKey string token string - stopGetConfig func() + stopGetConfig func() + pendingConfigChainId string // Netstack fields tun tun.Device tnet *netstack2.Net @@ -107,6 +111,13 @@ type WireGuardService struct { wgTesterServer *wgtester.Server } +// generateChainId generates a random chain ID for deduplicating round-trip messages. +func generateChainId() string { + b := make([]byte, 8) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { @@ -441,9 +452,12 @@ func (s *WireGuardService) LoadRemoteConfig() error { s.stopGetConfig() s.stopGetConfig = nil } + chainId := generateChainId() + s.pendingConfigChainId = chainId s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{ "publicKey": s.key.PublicKey().String(), "port": s.Port, + "chainId": chainId, }, 2*time.Second) logger.Debug("Requesting WireGuard configuration from remote server") @@ -468,6 +482,17 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { logger.Info("Error unmarshaling target data: %v", err) return } + + // Deduplicate using chainId: discard responses that don't match the + // pending request, or that we have already processed. 
+ if config.ChainId != "" { + if config.ChainId != s.pendingConfigChainId { + logger.Debug("Discarding duplicate/stale newt/wg/get-config response (chainId=%s, expected=%s)", config.ChainId, s.pendingConfigChainId) + return + } + s.pendingConfigChainId = "" // consume – further duplicates are rejected + } + s.config = config if s.stopGetConfig != nil { diff --git a/common.go b/common.go index 4701411..c55909a 100644 --- a/common.go +++ b/common.go @@ -285,11 +285,18 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien if tunnelID != "" { telemetry.IncReconnect(context.Background(), tunnelID, "client", telemetry.ReasonTimeout) } - stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) + pingChainId := generateChainId() + pendingPingChainId = pingChainId + stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{ + "chainId": pingChainId, + }, 3*time.Second) // Send registration message to the server for backward compatibility + bcChainId := generateChainId() + pendingRegisterChainId = bcChainId err := client.SendMessage("newt/wg/register", map[string]interface{}{ "publicKey": publicKey.String(), "backwardsCompatible": true, + "chainId": bcChainId, }) if err != nil { logger.Error("Failed to send registration message: %v", err) diff --git a/healthcheck/healthcheck.go b/healthcheck/healthcheck.go index 9889cc6..f618803 100644 --- a/healthcheck/healthcheck.go +++ b/healthcheck/healthcheck.go @@ -5,7 +5,9 @@ import ( "crypto/tls" "encoding/json" "fmt" + "net" "net/http" + "strconv" "strings" "sync" "time" @@ -365,11 +367,12 @@ func (m *Monitor) performHealthCheck(target *Target) { target.LastCheck = time.Now() target.LastError = "" - // Build URL - url := fmt.Sprintf("%s://%s", target.Config.Scheme, target.Config.Hostname) + // Build URL (use net.JoinHostPort to properly handle IPv6 addresses with ports) + host := target.Config.Hostname if target.Config.Port > 0 { - url = 
fmt.Sprintf("%s:%d", url, target.Config.Port) + host = net.JoinHostPort(target.Config.Hostname, strconv.Itoa(target.Config.Port)) } + url := fmt.Sprintf("%s://%s", target.Config.Scheme, host) if target.Config.Path != "" { if !strings.HasPrefix(target.Config.Path, "/") { url += "/" diff --git a/main.go b/main.go index c9e7d8d..e051450 100644 --- a/main.go +++ b/main.go @@ -3,13 +3,16 @@ package main import ( "bytes" "context" + "crypto/rand" "crypto/tls" + "encoding/hex" "encoding/json" "errors" "flag" "fmt" "net" "net/http" + "net/http/pprof" "net/netip" "os" "os/signal" @@ -45,6 +48,7 @@ type WgData struct { TunnelIP string `json:"tunnelIP"` Targets TargetsByType `json:"targets"` HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"` + ChainId string `json:"chainId"` } type TargetsByType struct { @@ -58,6 +62,7 @@ type TargetData struct { type ExitNodeData struct { ExitNodes []ExitNode `json:"exitNodes"` + ChainId string `json:"chainId"` } // ExitNode represents an exit node with an ID, endpoint, and weight. @@ -127,6 +132,8 @@ var ( publicKey wgtypes.Key pingStopChan chan struct{} stopFunc func() + pendingRegisterChainId string + pendingPingChainId string healthFile string useNativeInterface bool authorizedKeysFile string @@ -147,6 +154,7 @@ var ( adminAddr string region string metricsAsyncBytes bool + pprofEnabled bool blueprintFile string noCloud bool @@ -159,6 +167,13 @@ var ( tlsPrivateKey string ) +// generateChainId generates a random chain ID for deduplicating round-trip messages. 
+func generateChainId() string { + b := make([]byte, 8) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + func main() { // Check for subcommands first (only principals exits early) if len(os.Args) > 1 { @@ -225,6 +240,7 @@ func runNewtMain(ctx context.Context) { adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR") regionEnv := os.Getenv("NEWT_REGION") asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES") + pprofEnabledEnv := os.Getenv("NEWT_PPROF_ENABLED") disableClientsEnv := os.Getenv("DISABLE_CLIENTS") disableClients = disableClientsEnv == "true" @@ -390,6 +406,14 @@ func runNewtMain(ctx context.Context) { metricsAsyncBytes = v } } + // pprof debug endpoint toggle + if pprofEnabledEnv == "" { + flag.BoolVar(&pprofEnabled, "pprof", false, "Enable pprof debug endpoints on admin server") + } else { + if v, err := strconv.ParseBool(pprofEnabledEnv); err == nil { + pprofEnabled = v + } + } // Optional region flag (resource attribute) if regionEnv == "" { flag.StringVar(&region, "region", "", "Optional region resource attribute (also NEWT_REGION)") } @@ -485,6 +509,14 @@ func runNewtMain(ctx context.Context) { if tel.PrometheusHandler != nil { mux.Handle("/metrics", tel.PrometheusHandler) } + if pprofEnabled { + mux.HandleFunc("/debug/pprof/", pprof.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprof.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + logger.Info("pprof debugging enabled on %s/debug/pprof/", tcfg.AdminAddr) + } admin := &http.Server{ Addr: tcfg.AdminAddr, Handler: otelhttp.NewHandler(mux, "newt-admin"), @@ -687,6 +719,24 @@ func runNewtMain(ctx context.Context) { defer func() { telemetry.IncSiteRegistration(ctx, regResult) }() + + // Deduplicate using chainId: if the server echoes back a chainId we have + // already consumed (or one that doesn't match our current pending request), + // throw the message away to avoid setting up the 
tunnel twice. + var chainData struct { + ChainId string `json:"chainId"` + } + if jsonBytes, err := json.Marshal(msg.Data); err == nil { + _ = json.Unmarshal(jsonBytes, &chainData) + } + if chainData.ChainId != "" { + if chainData.ChainId != pendingRegisterChainId { + logger.Debug("Discarding duplicate/stale newt/wg/connect (chainId=%s, expected=%s)", chainData.ChainId, pendingRegisterChainId) + return + } + pendingRegisterChainId = "" // consume – further duplicates with this id are rejected + } + if stopFunc != nil { stopFunc() // stop the ws from sending more requests stopFunc = nil // reset stopFunc to nil to avoid double stopping @@ -871,8 +921,11 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( } // Request exit nodes from the server + pingChainId := generateChainId() + pendingPingChainId = pingChainId stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{ "noCloud": noCloud, + "chainId": pingChainId, }, 3*time.Second) logger.Info("Tunnel destroyed, ready for reconnection") @@ -901,6 +954,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) { logger.Debug("Received ping message") + if stopFunc != nil { stopFunc() // stop the ws from sending more requests stopFunc = nil // reset stopFunc to nil to avoid double stopping @@ -920,6 +974,14 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( } exitNodes := exitNodeData.ExitNodes + if exitNodeData.ChainId != "" { + if exitNodeData.ChainId != pendingPingChainId { + logger.Debug("Discarding duplicate/stale newt/ping/exitNodes (chainId=%s, expected=%s)", exitNodeData.ChainId, pendingPingChainId) + return + } + pendingPingChainId = "" // consume – further duplicates with this id are rejected + } + if len(exitNodes) == 0 { logger.Info("No exit nodes provided") return @@ -952,10 +1014,13 @@ 
persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( }, } + chainId := generateChainId() + pendingRegisterChainId = chainId stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{ "publicKey": publicKey.String(), "pingResults": pingResults, "newtVersion": newtVersion, + "chainId": chainId, }, 2*time.Second) return @@ -1055,10 +1120,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( } // Send the ping results to the cloud for selection + chainId := generateChainId() + pendingRegisterChainId = chainId stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{ "publicKey": publicKey.String(), "pingResults": pingResults, "newtVersion": newtVersion, + "chainId": chainId, }, 2*time.Second) logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults) @@ -1708,8 +1776,11 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( stopFunc() } // request from the server the list of nodes to ping + pingChainId := generateChainId() + pendingPingChainId = pingChainId stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{ "noCloud": noCloud, + "chainId": pingChainId, }, 3*time.Second) logger.Debug("Requesting exit nodes from server") @@ -1721,10 +1792,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( } // Send registration message to the server for backward compatibility + bcChainId := generateChainId() + pendingRegisterChainId = bcChainId err := client.SendMessage(topicWGRegister, map[string]interface{}{ "publicKey": publicKey.String(), "newtVersion": newtVersion, "backwardsCompatible": true, + "chainId": bcChainId, }) sendBlueprint(client) diff --git a/proxy/manager.go b/proxy/manager.go index 0619e80..5566589 100644 --- a/proxy/manager.go +++ b/proxy/manager.go @@ -21,7 +21,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" ) -const 
errUnsupportedProtoFmt = "unsupported protocol: %s" +const ( + errUnsupportedProtoFmt = "unsupported protocol: %s" + maxUDPPacketSize = 65507 +) // Target represents a proxy target with its address and port type Target struct { @@ -105,13 +108,9 @@ func classifyProxyError(err error) string { if errors.Is(err, net.ErrClosed) { return "closed" } - if ne, ok := err.(net.Error); ok { - if ne.Timeout() { - return "timeout" - } - if ne.Temporary() { - return "temporary" - } + var ne net.Error + if errors.As(err, &ne) && ne.Timeout() { + return "timeout" } msg := strings.ToLower(err.Error()) switch { @@ -437,14 +436,6 @@ func (pm *ProxyManager) Stop() error { pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...) } - // // Clear the target maps - // for k := range pm.tcpTargets { - // delete(pm.tcpTargets, k) - // } - // for k := range pm.udpTargets { - // delete(pm.udpTargets, k) - // } - // Give active connections a chance to close gracefully time.Sleep(100 * time.Millisecond) @@ -498,7 +489,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string) if !pm.running { return } - if ne, ok := err.(net.Error); ok && !ne.Temporary() { + if errors.Is(err, net.ErrClosed) { logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr()) return } @@ -564,7 +555,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string) } func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) { - buffer := make([]byte, 65507) // Max UDP packet size + buffer := make([]byte, maxUDPPacketSize) // Max UDP packet size clientConns := make(map[string]*net.UDPConn) var clientsMutex sync.RWMutex @@ -583,7 +574,7 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) { } // Check for connection closed conditions - if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") { + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { logger.Info("UDP 
connection closed, stopping proxy handler") // Clean up existing client connections @@ -662,10 +653,14 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) { telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed) }() - buffer := make([]byte, 65507) + buffer := make([]byte, maxUDPPacketSize) for { n, _, err := targetConn.ReadFromUDP(buffer) if err != nil { + // Connection closed is normal during cleanup + if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) { + return // defer will handle cleanup, result stays "success" + } logger.Error("Error reading from target: %v", err) result = "failure" return // defer will handle cleanup