mirror of
https://github.com/fosrl/newt.git
synced 2026-03-27 21:16:41 +00:00
Compare commits
7 Commits
dependabot
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1057013b50 | ||
|
|
a2683eb385 | ||
|
|
d3722c2519 | ||
|
|
8fda35db4f | ||
|
|
de4353f2e6 | ||
|
|
13448f76aa | ||
|
|
836144aebf |
@@ -2,6 +2,8 @@ package clients
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -34,6 +36,7 @@ type WgConfig struct {
|
||||
IpAddress string `json:"ipAddress"`
|
||||
Peers []Peer `json:"peers"`
|
||||
Targets []Target `json:"targets"`
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
|
||||
type Target struct {
|
||||
@@ -82,7 +85,8 @@ type WireGuardService struct {
|
||||
host string
|
||||
serverPubKey string
|
||||
token string
|
||||
stopGetConfig func()
|
||||
stopGetConfig func()
|
||||
pendingConfigChainId string
|
||||
// Netstack fields
|
||||
tun tun.Device
|
||||
tnet *netstack2.Net
|
||||
@@ -107,6 +111,13 @@ type WireGuardService struct {
|
||||
wgTesterServer *wgtester.Server
|
||||
}
|
||||
|
||||
// generateChainId generates a random chain ID for deduplicating round-trip messages.
|
||||
func generateChainId() string {
|
||||
b := make([]byte, 8)
|
||||
_, _ = rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
@@ -442,9 +453,12 @@ func (s *WireGuardService) LoadRemoteConfig() error {
|
||||
s.stopGetConfig()
|
||||
s.stopGetConfig = nil
|
||||
}
|
||||
chainId := generateChainId()
|
||||
s.pendingConfigChainId = chainId
|
||||
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
|
||||
"publicKey": s.key.PublicKey().String(),
|
||||
"port": s.Port,
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second)
|
||||
|
||||
logger.Debug("Requesting WireGuard configuration from remote server")
|
||||
@@ -469,6 +483,17 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Deduplicate using chainId: discard responses that don't match the
|
||||
// pending request, or that we have already processed.
|
||||
if config.ChainId != "" {
|
||||
if config.ChainId != s.pendingConfigChainId {
|
||||
logger.Debug("Discarding duplicate/stale newt/wg/get-config response (chainId=%s, expected=%s)", config.ChainId, s.pendingConfigChainId)
|
||||
return
|
||||
}
|
||||
s.pendingConfigChainId = "" // consume – further duplicates are rejected
|
||||
}
|
||||
|
||||
s.config = config
|
||||
|
||||
if s.stopGetConfig != nil {
|
||||
|
||||
@@ -287,9 +287,12 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
|
||||
}
|
||||
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
|
||||
// Send registration message to the server for backward compatibility
|
||||
bcChainId := generateChainId()
|
||||
pendingRegisterChainId = bcChainId
|
||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"backwardsCompatible": true,
|
||||
"chainId": bcChainId,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -365,11 +367,12 @@ func (m *Monitor) performHealthCheck(target *Target) {
|
||||
target.LastCheck = time.Now()
|
||||
target.LastError = ""
|
||||
|
||||
// Build URL
|
||||
url := fmt.Sprintf("%s://%s", target.Config.Scheme, target.Config.Hostname)
|
||||
// Build URL (use net.JoinHostPort to properly handle IPv6 addresses with ports)
|
||||
host := target.Config.Hostname
|
||||
if target.Config.Port > 0 {
|
||||
url = fmt.Sprintf("%s:%d", url, target.Config.Port)
|
||||
host = net.JoinHostPort(target.Config.Hostname, strconv.Itoa(target.Config.Port))
|
||||
}
|
||||
url := fmt.Sprintf("%s://%s", target.Config.Scheme, host)
|
||||
if target.Config.Path != "" {
|
||||
if !strings.HasPrefix(target.Config.Path, "/") {
|
||||
url += "/"
|
||||
|
||||
57
main.go
57
main.go
@@ -3,13 +3,16 @@ package main
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -45,6 +48,7 @@ type WgData struct {
|
||||
TunnelIP string `json:"tunnelIP"`
|
||||
Targets TargetsByType `json:"targets"`
|
||||
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
|
||||
type TargetsByType struct {
|
||||
@@ -127,6 +131,7 @@ var (
|
||||
publicKey wgtypes.Key
|
||||
pingStopChan chan struct{}
|
||||
stopFunc func()
|
||||
pendingRegisterChainId string
|
||||
healthFile string
|
||||
useNativeInterface bool
|
||||
authorizedKeysFile string
|
||||
@@ -147,6 +152,7 @@ var (
|
||||
adminAddr string
|
||||
region string
|
||||
metricsAsyncBytes bool
|
||||
pprofEnabled bool
|
||||
blueprintFile string
|
||||
noCloud bool
|
||||
|
||||
@@ -159,6 +165,13 @@ var (
|
||||
tlsPrivateKey string
|
||||
)
|
||||
|
||||
// generateChainId generates a random chain ID for deduplicating round-trip messages.
|
||||
func generateChainId() string {
|
||||
b := make([]byte, 8)
|
||||
_, _ = rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Check for subcommands first (only principals exits early)
|
||||
if len(os.Args) > 1 {
|
||||
@@ -225,6 +238,7 @@ func runNewtMain(ctx context.Context) {
|
||||
adminAddrEnv := os.Getenv("NEWT_ADMIN_ADDR")
|
||||
regionEnv := os.Getenv("NEWT_REGION")
|
||||
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
|
||||
pprofEnabledEnv := os.Getenv("NEWT_PPROF_ENABLED")
|
||||
|
||||
disableClientsEnv := os.Getenv("DISABLE_CLIENTS")
|
||||
disableClients = disableClientsEnv == "true"
|
||||
@@ -390,6 +404,14 @@ func runNewtMain(ctx context.Context) {
|
||||
metricsAsyncBytes = v
|
||||
}
|
||||
}
|
||||
// pprof debug endpoint toggle
|
||||
if pprofEnabledEnv == "" {
|
||||
flag.BoolVar(&pprofEnabled, "pprof", false, "Enable pprof debug endpoints on admin server")
|
||||
} else {
|
||||
if v, err := strconv.ParseBool(pprofEnabledEnv); err == nil {
|
||||
pprofEnabled = v
|
||||
}
|
||||
}
|
||||
// Optional region flag (resource attribute)
|
||||
if regionEnv == "" {
|
||||
flag.StringVar(®ion, "region", "", "Optional region resource attribute (also NEWT_REGION)")
|
||||
@@ -485,6 +507,14 @@ func runNewtMain(ctx context.Context) {
|
||||
if tel.PrometheusHandler != nil {
|
||||
mux.Handle("/metrics", tel.PrometheusHandler)
|
||||
}
|
||||
if pprofEnabled {
|
||||
mux.HandleFunc("/debug/pprof/", pprof.Index)
|
||||
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
|
||||
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
|
||||
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
|
||||
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
|
||||
logger.Info("pprof debugging enabled on %s/debug/pprof/", tcfg.AdminAddr)
|
||||
}
|
||||
admin := &http.Server{
|
||||
Addr: tcfg.AdminAddr,
|
||||
Handler: otelhttp.NewHandler(mux, "newt-admin"),
|
||||
@@ -687,6 +717,24 @@ func runNewtMain(ctx context.Context) {
|
||||
defer func() {
|
||||
telemetry.IncSiteRegistration(ctx, regResult)
|
||||
}()
|
||||
|
||||
// Deduplicate using chainId: if the server echoes back a chainId we have
|
||||
// already consumed (or one that doesn't match our current pending request),
|
||||
// throw the message away to avoid setting up the tunnel twice.
|
||||
var chainData struct {
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
if jsonBytes, err := json.Marshal(msg.Data); err == nil {
|
||||
_ = json.Unmarshal(jsonBytes, &chainData)
|
||||
}
|
||||
if chainData.ChainId != "" {
|
||||
if chainData.ChainId != pendingRegisterChainId {
|
||||
logger.Debug("Discarding duplicate/stale newt/wg/connect (chainId=%s, expected=%s)", chainData.ChainId, pendingRegisterChainId)
|
||||
return
|
||||
}
|
||||
pendingRegisterChainId = "" // consume – further duplicates with this id are rejected
|
||||
}
|
||||
|
||||
if stopFunc != nil {
|
||||
stopFunc() // stop the ws from sending more requests
|
||||
stopFunc = nil // reset stopFunc to nil to avoid double stopping
|
||||
@@ -952,10 +1000,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
},
|
||||
}
|
||||
|
||||
chainId := generateChainId()
|
||||
pendingRegisterChainId = chainId
|
||||
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"pingResults": pingResults,
|
||||
"newtVersion": newtVersion,
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second)
|
||||
|
||||
return
|
||||
@@ -1055,10 +1106,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
}
|
||||
|
||||
// Send the ping results to the cloud for selection
|
||||
chainId := generateChainId()
|
||||
pendingRegisterChainId = chainId
|
||||
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"pingResults": pingResults,
|
||||
"newtVersion": newtVersion,
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second)
|
||||
|
||||
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
|
||||
@@ -1721,10 +1775,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
}
|
||||
|
||||
// Send registration message to the server for backward compatibility
|
||||
bcChainId := generateChainId()
|
||||
pendingRegisterChainId = bcChainId
|
||||
err := client.SendMessage(topicWGRegister, map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"newtVersion": newtVersion,
|
||||
"backwardsCompatible": true,
|
||||
"chainId": bcChainId,
|
||||
})
|
||||
|
||||
sendBlueprint(client)
|
||||
|
||||
@@ -21,7 +21,10 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
)
|
||||
|
||||
const errUnsupportedProtoFmt = "unsupported protocol: %s"
|
||||
const (
|
||||
errUnsupportedProtoFmt = "unsupported protocol: %s"
|
||||
maxUDPPacketSize = 65507
|
||||
)
|
||||
|
||||
// Target represents a proxy target with its address and port
|
||||
type Target struct {
|
||||
@@ -105,13 +108,9 @@ func classifyProxyError(err error) string {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return "closed"
|
||||
}
|
||||
if ne, ok := err.(net.Error); ok {
|
||||
if ne.Timeout() {
|
||||
return "timeout"
|
||||
}
|
||||
if ne.Temporary() {
|
||||
return "temporary"
|
||||
}
|
||||
var ne net.Error
|
||||
if errors.As(err, &ne) && ne.Timeout() {
|
||||
return "timeout"
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
switch {
|
||||
@@ -437,14 +436,6 @@ func (pm *ProxyManager) Stop() error {
|
||||
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
||||
}
|
||||
|
||||
// // Clear the target maps
|
||||
// for k := range pm.tcpTargets {
|
||||
// delete(pm.tcpTargets, k)
|
||||
// }
|
||||
// for k := range pm.udpTargets {
|
||||
// delete(pm.udpTargets, k)
|
||||
// }
|
||||
|
||||
// Give active connections a chance to close gracefully
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
@@ -498,7 +489,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
|
||||
if !pm.running {
|
||||
return
|
||||
}
|
||||
if ne, ok := err.(net.Error); ok && !ne.Temporary() {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
|
||||
return
|
||||
}
|
||||
@@ -564,7 +555,7 @@ func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||
buffer := make([]byte, 65507) // Max UDP packet size
|
||||
buffer := make([]byte, maxUDPPacketSize) // Max UDP packet size
|
||||
clientConns := make(map[string]*net.UDPConn)
|
||||
var clientsMutex sync.RWMutex
|
||||
|
||||
@@ -583,7 +574,7 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||
}
|
||||
|
||||
// Check for connection closed conditions
|
||||
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
|
||||
logger.Info("UDP connection closed, stopping proxy handler")
|
||||
|
||||
// Clean up existing client connections
|
||||
@@ -662,10 +653,14 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||
telemetry.IncProxyConnectionEvent(context.Background(), tunnelID, "udp", telemetry.ProxyConnectionClosed)
|
||||
}()
|
||||
|
||||
buffer := make([]byte, 65507)
|
||||
buffer := make([]byte, maxUDPPacketSize)
|
||||
for {
|
||||
n, _, err := targetConn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
// Connection closed is normal during cleanup
|
||||
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
|
||||
return // defer will handle cleanup, result stays "success"
|
||||
}
|
||||
logger.Error("Error reading from target: %v", err)
|
||||
result = "failure"
|
||||
return // defer will handle cleanup
|
||||
|
||||
Reference in New Issue
Block a user