Merge pull request #285 from fosrl/dev

1.10.4
This commit is contained in:
Owen Schwartz
2026-03-29 12:25:10 -07:00
committed by GitHub
5 changed files with 129 additions and 25 deletions

View File

@@ -2,6 +2,8 @@ package clients
import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"net"
@@ -34,6 +36,7 @@ type WgConfig struct {
IpAddress string `json:"ipAddress"`
Peers []Peer `json:"peers"`
Targets []Target `json:"targets"`
ChainId string `json:"chainId"`
}
type Target struct {
@@ -82,7 +85,8 @@ type WireGuardService struct {
host string
serverPubKey string
token string
stopGetConfig func()
stopGetConfig func()
pendingConfigChainId string
// Netstack fields
tun tun.Device
tnet *netstack2.Net
@@ -107,6 +111,13 @@ type WireGuardService struct {
wgTesterServer *wgtester.Server
}
// generateChainId generates a random chain ID for deduplicating round-trip messages.
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
@@ -441,9 +452,12 @@ func (s *WireGuardService) LoadRemoteConfig() error {
s.stopGetConfig()
s.stopGetConfig = nil
}
chainId := generateChainId()
s.pendingConfigChainId = chainId
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
"publicKey": s.key.PublicKey().String(),
"port": s.Port,
"chainId": chainId,
}, 2*time.Second)
logger.Debug("Requesting WireGuard configuration from remote server")
@@ -468,6 +482,17 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
logger.Info("Error unmarshaling target data: %v", err)
return
}
// Deduplicate using chainId: discard responses that don't match the
// pending request, or that we have already processed.
if config.ChainId != "" {
if config.ChainId != s.pendingConfigChainId {
logger.Debug("Discarding duplicate/stale newt/wg/get-config response (chainId=%s, expected=%s)", config.ChainId, s.pendingConfigChainId)
return
}
s.pendingConfigChainId = "" // consume further duplicates are rejected
}
s.config = config
if s.stopGetConfig != nil {

View File

@@ -285,11 +285,18 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
if tunnelID != "" {
telemetry.IncReconnect(context.Background(), tunnelID, "client", telemetry.ReasonTimeout)
}
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
pingChainId := generateChainId()
pendingPingChainId = pingChainId
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
"chainId": pingChainId,
}, 3*time.Second)
// Send registration message to the server for backward compatibility
bcChainId := generateChainId()
pendingRegisterChainId = bcChainId
err := client.SendMessage("newt/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"backwardsCompatible": true,
"chainId": bcChainId,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)

View File

@@ -5,7 +5,9 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
@@ -365,11 +367,12 @@ func (m *Monitor) performHealthCheck(target *Target) {
target.LastCheck = time.Now()
target.LastError = ""
// Build URL
url := fmt.Sprintf("%s://%s", target.Config.Scheme, target.Config.Hostname)
// Build URL (use net.JoinHostPort to properly handle IPv6 addresses with ports)
host := target.Config.Hostname
if target.Config.Port > 0 {
url = fmt.Sprintf("%s:%d", url, target.Config.Port)
host = net.JoinHostPort(target.Config.Hostname, strconv.Itoa(target.Config.Port))
}
url := fmt.Sprintf("%s://%s", target.Config.Scheme, host)
if target.Config.Path != "" {
if !strings.HasPrefix(target.Config.Path, "/") {
url += "/"

74
main.go
View File

@@ -3,13 +3,16 @@ package main
import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"encoding/hex"
"encoding/json"
"errors"
"flag"
"fmt"
"net"
"net/http"
"net/http/pprof"
"net/netip"
"os"
"os/signal"
@@ -45,6 +48,7 @@ type WgData struct {
TunnelIP string `json:"tunnelIP"`
Targets TargetsByType `json:"targets"`
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
ChainId string `json:"chainId"`
}
type TargetsByType struct {
@@ -58,6 +62,7 @@ type TargetData struct {
type ExitNodeData struct {
ExitNodes []ExitNode `json:"exitNodes"`
ChainId string `json:"chainId"`
}
// ExitNode represents an exit node with an ID, endpoint, and weight.
@@ -127,6 +132,8 @@ var (
publicKey wgtypes.Key
pingStopChan chan struct{}
stopFunc func()
pendingRegisterChainId string
pendingPingChainId string
healthFile string
useNativeInterface bool
authorizedKeysFile string
@@ -147,6 +154,7 @@ var (
adminAddr string
region string
metricsAsyncBytes bool
pprofEnabled bool
blueprintFile string
noCloud bool
@@ -159,6 +167,13 @@ var (
tlsPrivateKey string
)
// generateChainId generates a random chain ID for deduplicating round-trip messages.
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func main() {
// Check for subcommands first (only principals exits early)
if len(os.Args) > 1 {
@@ -225,6 +240,7 @@ func runNewtMain(ctx context.Context) {
adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR")
regionEnv := os.Getenv("NEWT_REGION")
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
pprofEnabledEnv := os.Getenv("NEWT_PPROF_ENABLED")
disableClientsEnv := os.Getenv("DISABLE_CLIENTS")
disableClients = disableClientsEnv == "true"
@@ -390,6 +406,14 @@ func runNewtMain(ctx context.Context) {
metricsAsyncBytes = v
}
}
// pprof debug endpoint toggle
if pprofEnabledEnv == "" {
flag.BoolVar(&pprofEnabled, "pprof", false, "Enable pprof debug endpoints on admin server")
} else {
if v, err := strconv.ParseBool(pprofEnabledEnv); err == nil {
pprofEnabled = v
}
}
// Optional region flag (resource attribute)
if regionEnv == "" {
flag.StringVar(&region, "region", "", "Optional region resource attribute (also NEWT_REGION)")
@@ -485,6 +509,14 @@ func runNewtMain(ctx context.Context) {
if tel.PrometheusHandler != nil {
mux.Handle("/metrics", tel.PrometheusHandler)
}
if pprofEnabled {
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
logger.Info("pprof debugging enabled on %s/debug/pprof/", tcfg.AdminAddr)
}
admin := &http.Server{
Addr: tcfg.AdminAddr,
Handler: otelhttp.NewHandler(mux, "newt-admin"),
@@ -687,6 +719,24 @@ func runNewtMain(ctx context.Context) {
defer func() {
telemetry.IncSiteRegistration(ctx, regResult)
}()
// Deduplicate using chainId: if the server echoes back a chainId we have
// already consumed (or one that doesn't match our current pending request),
// throw the message away to avoid setting up the tunnel twice.
var chainData struct {
ChainId string `json:"chainId"`
}
if jsonBytes, err := json.Marshal(msg.Data); err == nil {
_ = json.Unmarshal(jsonBytes, &chainData)
}
if chainData.ChainId != "" {
if chainData.ChainId != pendingRegisterChainId {
logger.Debug("Discarding duplicate/stale newt/wg/connect (chainId=%s, expected=%s)", chainData.ChainId, pendingRegisterChainId)
return
}
pendingRegisterChainId = "" // consume further duplicates with this id are rejected
}
if stopFunc != nil {
stopFunc() // stop the ws from sending more requests
stopFunc = nil // reset stopFunc to nil to avoid double stopping
@@ -871,8 +921,11 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
}
// Request exit nodes from the server
pingChainId := generateChainId()
pendingPingChainId = pingChainId
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
"noCloud": noCloud,
"chainId": pingChainId,
}, 3*time.Second)
logger.Info("Tunnel destroyed, ready for reconnection")
@@ -901,6 +954,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) {
logger.Debug("Received ping message")
if stopFunc != nil {
stopFunc() // stop the ws from sending more requests
stopFunc = nil // reset stopFunc to nil to avoid double stopping
@@ -920,6 +974,14 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
}
exitNodes := exitNodeData.ExitNodes
if exitNodeData.ChainId != "" {
if exitNodeData.ChainId != pendingPingChainId {
logger.Debug("Discarding duplicate/stale newt/ping/exitNodes (chainId=%s, expected=%s)", exitNodeData.ChainId, pendingPingChainId)
return
}
pendingPingChainId = "" // consume further duplicates with this id are rejected
}
if len(exitNodes) == 0 {
logger.Info("No exit nodes provided")
return
@@ -952,10 +1014,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
},
}
chainId := generateChainId()
pendingRegisterChainId = chainId
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
"publicKey": publicKey.String(),
"pingResults": pingResults,
"newtVersion": newtVersion,
"chainId": chainId,
}, 2*time.Second)
return
@@ -1055,10 +1120,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
}
// Send the ping results to the cloud for selection
chainId := generateChainId()
pendingRegisterChainId = chainId
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
"publicKey": publicKey.String(),
"pingResults": pingResults,
"newtVersion": newtVersion,
"chainId": chainId,
}, 2*time.Second)
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
@@ -1708,8 +1776,11 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
stopFunc()
}
// request from the server the list of nodes to ping
pingChainId := generateChainId()
pendingPingChainId = pingChainId
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
"noCloud": noCloud,
"chainId": pingChainId,
}, 3*time.Second)
logger.Debug("Requesting exit nodes from server")
@@ -1721,10 +1792,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
}
// Send registration message to the server for backward compatibility
bcChainId := generateChainId()
pendingRegisterChainId = bcChainId
err := client.SendMessage(topicWGRegister, map[string]interface{}{
"publicKey": publicKey.String(),
"newtVersion": newtVersion,
"backwardsCompatible": true,
"chainId": bcChainId,
})
sendBlueprint(client)

View File

@@ -21,7 +21,10 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
)
const errUnsupportedProtoFmt = "unsupported protocol: %s"
const (
errUnsupportedProtoFmt = "unsupported protocol: %s"
maxUDPPacketSize = 65507
)
// Target represents a proxy target with its address and port
type Target struct {
@@ -105,13 +108,9 @@ func classifyProxyError(err error) string {
if errors.Is(err, net.ErrClosed) {
return "closed"
}
if ne, ok := err.(net.Error); ok {
if ne.Timeout() {
return "timeout"
}
if ne.Temporary() {
return "temporary"
}
var ne net.Error
if errors.As(err, &ne) && ne.Timeout() {
return "timeout"
}
msg := strings.ToLower(err.Error())
switch {
@@ -437,14 +436,6 @@ func (pm *ProxyManager) Stop() error {
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
}
// // Clear the target maps
// for k := range pm.tcpTargets {
// delete(pm.tcpTargets, k)
// }
// for k := range pm.udpTargets {
// delete(pm.udpTargets, k)
// }
// Give active connections a chance to close gracefully
time.Sleep(100 * time.Millisecond)
@@ -498,7 +489,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
if !pm.running {
return
}
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
if errors.Is(err, net.ErrClosed) {
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
return
}
@@ -564,7 +555,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
}
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
buffer := make([]byte, 65507) // Max UDP packet size
buffer := make([]byte, maxUDPPacketSize) // Max UDP packet size
clientConns := make(map[string]*net.UDPConn)
var clientsMutex sync.RWMutex
@@ -583,7 +574,7 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
}
// Check for connection closed conditions
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
logger.Info("UDP connection closed, stopping proxy handler")
// Clean up existing client connections
@@ -662,10 +653,14 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed)
}()
buffer := make([]byte, 65507)
buffer := make([]byte, maxUDPPacketSize)
for {
n, _, err := targetConn.ReadFromUDP(buffer)
if err != nil {
// Connection closed is normal during cleanup
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
return // defer will handle cleanup, result stays "success"
}
logger.Error("Error reading from target: %v", err)
result = "failure"
return // defer will handle cleanup