Merge branch 'main' into logging-provision

This commit is contained in:
Owen
2026-03-29 21:19:53 -07:00
6 changed files with 226 additions and 24 deletions

View File

@@ -2,6 +2,8 @@ package clients
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
@@ -34,6 +36,7 @@ type WgConfig struct {
IpAddress string `json:"ipAddress"` IpAddress string `json:"ipAddress"`
Peers []Peer `json:"peers"` Peers []Peer `json:"peers"`
Targets []Target `json:"targets"` Targets []Target `json:"targets"`
ChainId string `json:"chainId"`
} }
type Target struct { type Target struct {
@@ -83,7 +86,8 @@ type WireGuardService struct {
host string host string
serverPubKey string serverPubKey string
token string token string
stopGetConfig func() stopGetConfig func()
pendingConfigChainId string
// Netstack fields // Netstack fields
tun tun.Device tun tun.Device
tnet *netstack2.Net tnet *netstack2.Net
@@ -108,6 +112,13 @@ type WireGuardService struct {
wgTesterServer *wgtester.Server wgTesterServer *wgtester.Server
} }
// generateChainId generates a random chain ID for deduplicating round-trip messages.
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) { func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
@@ -162,9 +173,8 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
useNativeInterface: useNativeInterface, useNativeInterface: useNativeInterface,
} }
// Create the holepunch manager with ResolveDomain function // Create the holepunch manager
// We'll need to pass a domain resolver function service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String(), nil)
service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String())
// Register websocket handlers // Register websocket handlers
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig) wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
@@ -452,9 +462,12 @@ func (s *WireGuardService) LoadRemoteConfig() error {
s.stopGetConfig() s.stopGetConfig()
s.stopGetConfig = nil s.stopGetConfig = nil
} }
chainId := generateChainId()
s.pendingConfigChainId = chainId
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{ s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
"publicKey": s.key.PublicKey().String(), "publicKey": s.key.PublicKey().String(),
"port": s.Port, "port": s.Port,
"chainId": chainId,
}, 2*time.Second) }, 2*time.Second)
logger.Debug("Requesting WireGuard configuration from remote server") logger.Debug("Requesting WireGuard configuration from remote server")
@@ -479,6 +492,17 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
logger.Info("Error unmarshaling target data: %v", err) logger.Info("Error unmarshaling target data: %v", err)
return return
} }
// Deduplicate using chainId: discard responses that don't match the
// pending request, or that we have already processed.
if config.ChainId != "" {
if config.ChainId != s.pendingConfigChainId {
logger.Debug("Discarding duplicate/stale newt/wg/get-config response (chainId=%s, expected=%s)", config.ChainId, s.pendingConfigChainId)
return
}
s.pendingConfigChainId = "" // consume further duplicates are rejected
}
s.config = config s.config = config
if s.stopGetConfig != nil { if s.stopGetConfig != nil {

View File

@@ -286,11 +286,18 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
if tunnelID != "" { if tunnelID != "" {
telemetry.IncReconnect(context.Background(), tunnelID, "client", telemetry.ReasonTimeout) telemetry.IncReconnect(context.Background(), tunnelID, "client", telemetry.ReasonTimeout)
} }
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) pingChainId := generateChainId()
pendingPingChainId = pingChainId
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
"chainId": pingChainId,
}, 3*time.Second)
// Send registration message to the server for backward compatibility // Send registration message to the server for backward compatibility
bcChainId := generateChainId()
pendingRegisterChainId = bcChainId
err := client.SendMessage("newt/wg/register", map[string]interface{}{ err := client.SendMessage("newt/wg/register", map[string]interface{}{
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"backwardsCompatible": true, "backwardsCompatible": true,
"chainId": bcChainId,
}) })
if err != nil { if err != nil {
logger.Error("Failed to send registration message: %v", err) logger.Error("Failed to send registration message: %v", err)

View File

@@ -27,16 +27,17 @@ type ExitNode struct {
// Manager handles UDP hole punching operations // Manager handles UDP hole punching operations
type Manager struct { type Manager struct {
mu sync.Mutex mu sync.Mutex
running bool running bool
stopChan chan struct{} stopChan chan struct{}
sharedBind *bind.SharedBind sharedBind *bind.SharedBind
ID string ID string
token string token string
publicKey string publicKey string
clientType string clientType string
exitNodes map[string]ExitNode // key is endpoint exitNodes map[string]ExitNode // key is endpoint
updateChan chan struct{} // signals the goroutine to refresh exit nodes updateChan chan struct{} // signals the goroutine to refresh exit nodes
publicDNS []string
sendHolepunchInterval time.Duration sendHolepunchInterval time.Duration
sendHolepunchIntervalMin time.Duration sendHolepunchIntervalMin time.Duration
@@ -49,12 +50,13 @@ const defaultSendHolepunchIntervalMax = 60 * time.Second
const defaultSendHolepunchIntervalMin = 1 * time.Second const defaultSendHolepunchIntervalMin = 1 * time.Second
// NewManager creates a new hole punch manager // NewManager creates a new hole punch manager
func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string) *Manager { func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string, publicDNS []string) *Manager {
return &Manager{ return &Manager{
sharedBind: sharedBind, sharedBind: sharedBind,
ID: ID, ID: ID,
clientType: clientType, clientType: clientType,
publicKey: publicKey, publicKey: publicKey,
publicDNS: publicDNS,
exitNodes: make(map[string]ExitNode), exitNodes: make(map[string]ExitNode),
sendHolepunchInterval: defaultSendHolepunchIntervalMin, sendHolepunchInterval: defaultSendHolepunchIntervalMin,
sendHolepunchIntervalMin: defaultSendHolepunchIntervalMin, sendHolepunchIntervalMin: defaultSendHolepunchIntervalMin,
@@ -281,7 +283,13 @@ func (m *Manager) TriggerHolePunch() error {
// Send hole punch to all exit nodes // Send hole punch to all exit nodes
successCount := 0 successCount := 0
for _, exitNode := range currentExitNodes { for _, exitNode := range currentExitNodes {
host, err := util.ResolveDomain(exitNode.Endpoint) var host string
var err error
if len(m.publicDNS) > 0 {
host, err = util.ResolveDomainUpstream(exitNode.Endpoint, m.publicDNS)
} else {
host, err = util.ResolveDomain(exitNode.Endpoint)
}
if err != nil { if err != nil {
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err) logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue continue
@@ -392,7 +400,13 @@ func (m *Manager) runMultipleExitNodes() {
var resolvedNodes []resolvedExitNode var resolvedNodes []resolvedExitNode
for _, exitNode := range currentExitNodes { for _, exitNode := range currentExitNodes {
host, err := util.ResolveDomain(exitNode.Endpoint) var host string
var err error
if len(m.publicDNS) > 0 {
host, err = util.ResolveDomainUpstream(exitNode.Endpoint, m.publicDNS)
} else {
host, err = util.ResolveDomain(exitNode.Endpoint)
}
if err != nil { if err != nil {
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err) logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue continue

View File

@@ -49,10 +49,11 @@ type cachedAddr struct {
// HolepunchTester monitors holepunch connectivity using magic packets // HolepunchTester monitors holepunch connectivity using magic packets
type HolepunchTester struct { type HolepunchTester struct {
sharedBind *bind.SharedBind sharedBind *bind.SharedBind
mu sync.RWMutex publicDNS []string
running bool mu sync.RWMutex
stopChan chan struct{} running bool
stopChan chan struct{}
// Pending requests waiting for responses (key: echo data as string) // Pending requests waiting for responses (key: echo data as string)
pendingRequests sync.Map // map[string]*pendingRequest pendingRequests sync.Map // map[string]*pendingRequest
@@ -84,9 +85,10 @@ type pendingRequest struct {
} }
// NewHolepunchTester creates a new holepunch tester using the given SharedBind // NewHolepunchTester creates a new holepunch tester using the given SharedBind
func NewHolepunchTester(sharedBind *bind.SharedBind) *HolepunchTester { func NewHolepunchTester(sharedBind *bind.SharedBind, publicDNS []string) *HolepunchTester {
return &HolepunchTester{ return &HolepunchTester{
sharedBind: sharedBind, sharedBind: sharedBind,
publicDNS: publicDNS,
addrCache: make(map[string]*cachedAddr), addrCache: make(map[string]*cachedAddr),
addrCacheTTL: 5 * time.Minute, // Cache addresses for 5 minutes addrCacheTTL: 5 * time.Minute, // Cache addresses for 5 minutes
} }
@@ -169,7 +171,13 @@ func (t *HolepunchTester) resolveEndpoint(endpoint string) (*net.UDPAddr, error)
} }
// Resolve the endpoint // Resolve the endpoint
host, err := util.ResolveDomain(endpoint) var host string
var err error
if len(t.publicDNS) > 0 {
host, err = util.ResolveDomainUpstream(endpoint, t.publicDNS)
} else {
host, err = util.ResolveDomain(endpoint)
}
if err != nil { if err != nil {
host = endpoint host = endpoint
} }

55
main.go
View File

@@ -3,7 +3,9 @@ package main
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/rand"
"crypto/tls" "crypto/tls"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"flag" "flag"
@@ -46,6 +48,7 @@ type WgData struct {
TunnelIP string `json:"tunnelIP"` TunnelIP string `json:"tunnelIP"`
Targets TargetsByType `json:"targets"` Targets TargetsByType `json:"targets"`
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"` HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
ChainId string `json:"chainId"`
} }
type TargetsByType struct { type TargetsByType struct {
@@ -59,6 +62,7 @@ type TargetData struct {
type ExitNodeData struct { type ExitNodeData struct {
ExitNodes []ExitNode `json:"exitNodes"` ExitNodes []ExitNode `json:"exitNodes"`
ChainId string `json:"chainId"`
} }
// ExitNode represents an exit node with an ID, endpoint, and weight. // ExitNode represents an exit node with an ID, endpoint, and weight.
@@ -128,6 +132,8 @@ var (
publicKey wgtypes.Key publicKey wgtypes.Key
pingStopChan chan struct{} pingStopChan chan struct{}
stopFunc func() stopFunc func()
pendingRegisterChainId string
pendingPingChainId string
healthFile string healthFile string
useNativeInterface bool useNativeInterface bool
authorizedKeysFile string authorizedKeysFile string
@@ -167,6 +173,13 @@ var (
configFile string configFile string
) )
// generateChainId generates a random chain ID for deduplicating round-trip messages.
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func main() { func main() {
// Check for subcommands first (only principals exits early) // Check for subcommands first (only principals exits early)
if len(os.Args) > 1 { if len(os.Args) > 1 {
@@ -727,6 +740,24 @@ func runNewtMain(ctx context.Context) {
defer func() { defer func() {
telemetry.IncSiteRegistration(ctx, regResult) telemetry.IncSiteRegistration(ctx, regResult)
}() }()
// Deduplicate using chainId: if the server echoes back a chainId we have
// already consumed (or one that doesn't match our current pending request),
// throw the message away to avoid setting up the tunnel twice.
var chainData struct {
ChainId string `json:"chainId"`
}
if jsonBytes, err := json.Marshal(msg.Data); err == nil {
_ = json.Unmarshal(jsonBytes, &chainData)
}
if chainData.ChainId != "" {
if chainData.ChainId != pendingRegisterChainId {
logger.Debug("Discarding duplicate/stale newt/wg/connect (chainId=%s, expected=%s)", chainData.ChainId, pendingRegisterChainId)
return
}
pendingRegisterChainId = "" // consume further duplicates with this id are rejected
}
if stopFunc != nil { if stopFunc != nil {
stopFunc() // stop the ws from sending more requests stopFunc() // stop the ws from sending more requests
stopFunc = nil // reset stopFunc to nil to avoid double stopping stopFunc = nil // reset stopFunc to nil to avoid double stopping
@@ -911,8 +942,11 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
} }
// Request exit nodes from the server // Request exit nodes from the server
pingChainId := generateChainId()
pendingPingChainId = pingChainId
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{ stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
"noCloud": noCloud, "noCloud": noCloud,
"chainId": pingChainId,
}, 3*time.Second) }, 3*time.Second)
logger.Info("Tunnel destroyed, ready for reconnection") logger.Info("Tunnel destroyed, ready for reconnection")
@@ -941,6 +975,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) { client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) {
logger.Debug("Received ping message") logger.Debug("Received ping message")
if stopFunc != nil { if stopFunc != nil {
stopFunc() // stop the ws from sending more requests stopFunc() // stop the ws from sending more requests
stopFunc = nil // reset stopFunc to nil to avoid double stopping stopFunc = nil // reset stopFunc to nil to avoid double stopping
@@ -960,6 +995,14 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
} }
exitNodes := exitNodeData.ExitNodes exitNodes := exitNodeData.ExitNodes
if exitNodeData.ChainId != "" {
if exitNodeData.ChainId != pendingPingChainId {
logger.Debug("Discarding duplicate/stale newt/ping/exitNodes (chainId=%s, expected=%s)", exitNodeData.ChainId, pendingPingChainId)
return
}
pendingPingChainId = "" // consume further duplicates with this id are rejected
}
if len(exitNodes) == 0 { if len(exitNodes) == 0 {
logger.Info("No exit nodes provided") logger.Info("No exit nodes provided")
return return
@@ -992,10 +1035,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
}, },
} }
chainId := generateChainId()
pendingRegisterChainId = chainId
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{ stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"pingResults": pingResults, "pingResults": pingResults,
"newtVersion": newtVersion, "newtVersion": newtVersion,
"chainId": chainId,
}, 2*time.Second) }, 2*time.Second)
return return
@@ -1095,10 +1141,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
} }
// Send the ping results to the cloud for selection // Send the ping results to the cloud for selection
chainId := generateChainId()
pendingRegisterChainId = chainId
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{ stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"pingResults": pingResults, "pingResults": pingResults,
"newtVersion": newtVersion, "newtVersion": newtVersion,
"chainId": chainId,
}, 2*time.Second) }, 2*time.Second)
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults) logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
@@ -1748,8 +1797,11 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
stopFunc() stopFunc()
} }
// request from the server the list of nodes to ping // request from the server the list of nodes to ping
pingChainId := generateChainId()
pendingPingChainId = pingChainId
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{ stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
"noCloud": noCloud, "noCloud": noCloud,
"chainId": pingChainId,
}, 3*time.Second) }, 3*time.Second)
logger.Debug("Requesting exit nodes from server") logger.Debug("Requesting exit nodes from server")
@@ -1761,10 +1813,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
} }
// Send registration message to the server for backward compatibility // Send registration message to the server for backward compatibility
bcChainId := generateChainId()
pendingRegisterChainId = bcChainId
err := client.SendMessage(topicWGRegister, map[string]interface{}{ err := client.SendMessage(topicWGRegister, map[string]interface{}{
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"newtVersion": newtVersion, "newtVersion": newtVersion,
"backwardsCompatible": true, "backwardsCompatible": true,
"chainId": bcChainId,
}) })
sendBlueprint(client) sendBlueprint(client)

View File

@@ -1,6 +1,7 @@
package util package util
import ( import (
"context"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
@@ -14,6 +15,99 @@ import (
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
) )
func ResolveDomainUpstream(domain string, publicDNS []string) (string, error) {
// trim whitespace
domain = strings.TrimSpace(domain)
// Remove any protocol prefix if present (do this first, before splitting host/port)
domain = strings.TrimPrefix(domain, "http://")
domain = strings.TrimPrefix(domain, "https://")
// if there are any trailing slashes, remove them
domain = strings.TrimSuffix(domain, "/")
// Check if there's a port in the domain
host, port, err := net.SplitHostPort(domain)
if err != nil {
// No port found, use the domain as is
host = domain
port = ""
}
// Check if host is already an IP address (IPv4 or IPv6)
// For IPv6, the host from SplitHostPort will already have brackets stripped
// but if there was no port, we need to handle bracketed IPv6 addresses
cleanHost := strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[")
if ip := net.ParseIP(cleanHost); ip != nil {
// It's already an IP address, no need to resolve
ipAddr := ip.String()
if port != "" {
return net.JoinHostPort(ipAddr, port), nil
}
return ipAddr, nil
}
// Lookup IP addresses using the upstream DNS servers if provided
var ips []net.IP
if len(publicDNS) > 0 {
var lastErr error
for _, server := range publicDNS {
// Ensure the upstream DNS address has a port
dnsAddr := server
if _, _, err := net.SplitHostPort(dnsAddr); err != nil {
// No port specified, default to 53
dnsAddr = net.JoinHostPort(server, "53")
}
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{}
return d.DialContext(ctx, "udp", dnsAddr)
},
}
ips, lastErr = resolver.LookupIP(context.Background(), "ip", host)
if lastErr == nil {
break
}
}
if lastErr != nil {
return "", fmt.Errorf("DNS lookup failed using all upstream servers: %v", lastErr)
}
} else {
ips, err = net.LookupIP(host)
if err != nil {
return "", fmt.Errorf("DNS lookup failed: %v", err)
}
}
if len(ips) == 0 {
return "", fmt.Errorf("no IP addresses found for domain %s", host)
}
// Get the first IPv4 address if available
var ipAddr string
for _, ip := range ips {
if ipv4 := ip.To4(); ipv4 != nil {
ipAddr = ipv4.String()
break
}
}
// If no IPv4 found, use the first IP (might be IPv6)
if ipAddr == "" {
ipAddr = ips[0].String()
}
// Add port back if it existed
if port != "" {
ipAddr = net.JoinHostPort(ipAddr, port)
}
return ipAddr, nil
}
func ResolveDomain(domain string) (string, error) { func ResolveDomain(domain string) (string, error) {
// trim whitespace // trim whitespace
domain = strings.TrimSpace(domain) domain = strings.TrimSpace(domain)