diff --git a/clients/clients.go b/clients/clients.go index 9223262..78bc0c3 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -2,6 +2,8 @@ package clients import ( "context" + "crypto/rand" + "encoding/hex" "encoding/json" "fmt" "net" @@ -34,6 +36,7 @@ type WgConfig struct { IpAddress string `json:"ipAddress"` Peers []Peer `json:"peers"` Targets []Target `json:"targets"` + ChainId string `json:"chainId"` } type Target struct { @@ -83,7 +86,8 @@ type WireGuardService struct { host string serverPubKey string token string - stopGetConfig func() + stopGetConfig func() + pendingConfigChainId string // Netstack fields tun tun.Device tnet *netstack2.Net @@ -108,6 +112,13 @@ type WireGuardService struct { wgTesterServer *wgtester.Server } +// generateChainId generates a random chain ID for deduplicating round-trip messages. +func generateChainId() string { + b := make([]byte, 8) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { @@ -162,9 +173,8 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string useNativeInterface: useNativeInterface, } - // Create the holepunch manager with ResolveDomain function - // We'll need to pass a domain resolver function - service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String()) + // Create the holepunch manager + service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String(), nil) // Register websocket handlers wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig) @@ -452,9 +462,12 @@ func (s *WireGuardService) LoadRemoteConfig() error { s.stopGetConfig() s.stopGetConfig = nil } + chainId := generateChainId() + s.pendingConfigChainId = chainId 
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{ "publicKey": s.key.PublicKey().String(), "port": s.Port, + "chainId": chainId, }, 2*time.Second) logger.Debug("Requesting WireGuard configuration from remote server") @@ -479,6 +492,17 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { logger.Info("Error unmarshaling target data: %v", err) return } + + // Deduplicate using chainId: discard responses that don't match the + // pending request, or that we have already processed. + if config.ChainId != "" { + if config.ChainId != s.pendingConfigChainId { + logger.Debug("Discarding duplicate/stale newt/wg/get-config response (chainId=%s, expected=%s)", config.ChainId, s.pendingConfigChainId) + return + } + s.pendingConfigChainId = "" // consume – further duplicates are rejected + } + s.config = config if s.stopGetConfig != nil { diff --git a/common.go b/common.go index 34e3cd0..e215813 100644 --- a/common.go +++ b/common.go @@ -286,11 +286,18 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien if tunnelID != "" { telemetry.IncReconnect(context.Background(), tunnelID, "client", telemetry.ReasonTimeout) } - stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) + pingChainId := generateChainId() + pendingPingChainId = pingChainId + stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{ + "chainId": pingChainId, + }, 3*time.Second) // Send registration message to the server for backward compatibility + bcChainId := generateChainId() + pendingRegisterChainId = bcChainId err := client.SendMessage("newt/wg/register", map[string]interface{}{ "publicKey": publicKey.String(), "backwardsCompatible": true, + "chainId": bcChainId, }) if err != nil { logger.Error("Failed to send registration message: %v", err) diff --git a/holepunch/holepunch.go b/holepunch/holepunch.go index 85679a9..8000837 100644 --- 
a/holepunch/holepunch.go +++ b/holepunch/holepunch.go @@ -27,16 +27,17 @@ type ExitNode struct { // Manager handles UDP hole punching operations type Manager struct { - mu sync.Mutex - running bool - stopChan chan struct{} - sharedBind *bind.SharedBind - ID string - token string - publicKey string - clientType string - exitNodes map[string]ExitNode // key is endpoint - updateChan chan struct{} // signals the goroutine to refresh exit nodes + mu sync.Mutex + running bool + stopChan chan struct{} + sharedBind *bind.SharedBind + ID string + token string + publicKey string + clientType string + exitNodes map[string]ExitNode // key is endpoint + updateChan chan struct{} // signals the goroutine to refresh exit nodes + publicDNS []string sendHolepunchInterval time.Duration sendHolepunchIntervalMin time.Duration @@ -49,12 +50,13 @@ const defaultSendHolepunchIntervalMax = 60 * time.Second const defaultSendHolepunchIntervalMin = 1 * time.Second // NewManager creates a new hole punch manager -func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string) *Manager { +func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string, publicDNS []string) *Manager { return &Manager{ sharedBind: sharedBind, ID: ID, clientType: clientType, publicKey: publicKey, + publicDNS: publicDNS, exitNodes: make(map[string]ExitNode), sendHolepunchInterval: defaultSendHolepunchIntervalMin, sendHolepunchIntervalMin: defaultSendHolepunchIntervalMin, @@ -281,7 +283,13 @@ func (m *Manager) TriggerHolePunch() error { // Send hole punch to all exit nodes successCount := 0 for _, exitNode := range currentExitNodes { - host, err := util.ResolveDomain(exitNode.Endpoint) + var host string + var err error + if len(m.publicDNS) > 0 { + host, err = util.ResolveDomainUpstream(exitNode.Endpoint, m.publicDNS) + } else { + host, err = util.ResolveDomain(exitNode.Endpoint) + } if err != nil { logger.Warn("Failed to resolve endpoint %s: %v", 
exitNode.Endpoint, err) continue @@ -392,7 +400,13 @@ func (m *Manager) runMultipleExitNodes() { var resolvedNodes []resolvedExitNode for _, exitNode := range currentExitNodes { - host, err := util.ResolveDomain(exitNode.Endpoint) + var host string + var err error + if len(m.publicDNS) > 0 { + host, err = util.ResolveDomainUpstream(exitNode.Endpoint, m.publicDNS) + } else { + host, err = util.ResolveDomain(exitNode.Endpoint) + } if err != nil { logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err) continue diff --git a/holepunch/tester.go b/holepunch/tester.go index 9fb83df..85b8c89 100644 --- a/holepunch/tester.go +++ b/holepunch/tester.go @@ -49,10 +49,11 @@ type cachedAddr struct { // HolepunchTester monitors holepunch connectivity using magic packets type HolepunchTester struct { - sharedBind *bind.SharedBind - mu sync.RWMutex - running bool - stopChan chan struct{} + sharedBind *bind.SharedBind + publicDNS []string + mu sync.RWMutex + running bool + stopChan chan struct{} // Pending requests waiting for responses (key: echo data as string) pendingRequests sync.Map // map[string]*pendingRequest @@ -84,9 +85,10 @@ type pendingRequest struct { } // NewHolepunchTester creates a new holepunch tester using the given SharedBind -func NewHolepunchTester(sharedBind *bind.SharedBind) *HolepunchTester { +func NewHolepunchTester(sharedBind *bind.SharedBind, publicDNS []string) *HolepunchTester { return &HolepunchTester{ sharedBind: sharedBind, + publicDNS: publicDNS, addrCache: make(map[string]*cachedAddr), addrCacheTTL: 5 * time.Minute, // Cache addresses for 5 minutes } @@ -169,7 +171,13 @@ func (t *HolepunchTester) resolveEndpoint(endpoint string) (*net.UDPAddr, error) } // Resolve the endpoint - host, err := util.ResolveDomain(endpoint) + var host string + var err error + if len(t.publicDNS) > 0 { + host, err = util.ResolveDomainUpstream(endpoint, t.publicDNS) + } else { + host, err = util.ResolveDomain(endpoint) + } if err != nil { host = endpoint 
} diff --git a/main.go b/main.go index 0af8773..c573ee2 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,9 @@ package main import ( "bytes" "context" + "crypto/rand" "crypto/tls" + "encoding/hex" "encoding/json" "errors" "flag" @@ -46,6 +48,7 @@ type WgData struct { TunnelIP string `json:"tunnelIP"` Targets TargetsByType `json:"targets"` HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"` + ChainId string `json:"chainId"` } type TargetsByType struct { @@ -59,6 +62,7 @@ type TargetData struct { type ExitNodeData struct { ExitNodes []ExitNode `json:"exitNodes"` + ChainId string `json:"chainId"` } // ExitNode represents an exit node with an ID, endpoint, and weight. @@ -128,6 +132,8 @@ var ( publicKey wgtypes.Key pingStopChan chan struct{} stopFunc func() + pendingRegisterChainId string + pendingPingChainId string healthFile string useNativeInterface bool authorizedKeysFile string @@ -167,6 +173,13 @@ var ( configFile string ) +// generateChainId generates a random chain ID for deduplicating round-trip messages. +func generateChainId() string { + b := make([]byte, 8) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + func main() { // Check for subcommands first (only principals exits early) if len(os.Args) > 1 { @@ -727,6 +740,24 @@ func runNewtMain(ctx context.Context) { defer func() { telemetry.IncSiteRegistration(ctx, regResult) }() + + // Deduplicate using chainId: if the server echoes back a chainId we have + // already consumed (or one that doesn't match our current pending request), + // throw the message away to avoid setting up the tunnel twice. 
+ var chainData struct { + ChainId string `json:"chainId"` + } + if jsonBytes, err := json.Marshal(msg.Data); err == nil { + _ = json.Unmarshal(jsonBytes, &chainData) + } + if chainData.ChainId != "" { + if chainData.ChainId != pendingRegisterChainId { + logger.Debug("Discarding duplicate/stale newt/wg/connect (chainId=%s, expected=%s)", chainData.ChainId, pendingRegisterChainId) + return + } + pendingRegisterChainId = "" // consume – further duplicates with this id are rejected + } + if stopFunc != nil { stopFunc() // stop the ws from sending more requests stopFunc = nil // reset stopFunc to nil to avoid double stopping @@ -911,8 +942,11 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( } // Request exit nodes from the server + pingChainId := generateChainId() + pendingPingChainId = pingChainId stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{ "noCloud": noCloud, + "chainId": pingChainId, }, 3*time.Second) logger.Info("Tunnel destroyed, ready for reconnection") @@ -941,6 +975,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) { logger.Debug("Received ping message") + if stopFunc != nil { stopFunc() // stop the ws from sending more requests stopFunc = nil // reset stopFunc to nil to avoid double stopping @@ -960,6 +995,14 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( } exitNodes := exitNodeData.ExitNodes + if exitNodeData.ChainId != "" { + if exitNodeData.ChainId != pendingPingChainId { + logger.Debug("Discarding duplicate/stale newt/ping/exitNodes (chainId=%s, expected=%s)", exitNodeData.ChainId, pendingPingChainId) + return + } + pendingPingChainId = "" // consume – further duplicates with this id are rejected + } + if len(exitNodes) == 0 { logger.Info("No exit nodes provided") return @@ -992,10 +1035,13 @@ persistent_keepalive_interval=5`, 
util.FixKey(privateKey.String()), util.FixKey( }, } + chainId := generateChainId() + pendingRegisterChainId = chainId stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{ "publicKey": publicKey.String(), "pingResults": pingResults, "newtVersion": newtVersion, + "chainId": chainId, }, 2*time.Second) return @@ -1095,10 +1141,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( } // Send the ping results to the cloud for selection + chainId := generateChainId() + pendingRegisterChainId = chainId stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{ "publicKey": publicKey.String(), "pingResults": pingResults, "newtVersion": newtVersion, + "chainId": chainId, }, 2*time.Second) logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults) @@ -1748,8 +1797,11 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( stopFunc() } // request from the server the list of nodes to ping + pingChainId := generateChainId() + pendingPingChainId = pingChainId stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{ "noCloud": noCloud, + "chainId": pingChainId, }, 3*time.Second) logger.Debug("Requesting exit nodes from server") @@ -1761,10 +1813,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( } // Send registration message to the server for backward compatibility + bcChainId := generateChainId() + pendingRegisterChainId = bcChainId err := client.SendMessage(topicWGRegister, map[string]interface{}{ "publicKey": publicKey.String(), "newtVersion": newtVersion, "backwardsCompatible": true, + "chainId": bcChainId, }) sendBlueprint(client) diff --git a/util/util.go b/util/util.go index 58221c4..0ce5dee 100644 --- a/util/util.go +++ b/util/util.go @@ -1,6 +1,7 @@ package util import ( + "context" "encoding/base64" "encoding/binary" "encoding/hex" @@ -14,6 +15,99 @@ import ( 
// ResolveDomainUpstream resolves domain to an IP address string, querying the
// DNS servers in publicDNS instead of the system resolver when any are given.
//
// domain may be a bare host, "host:port", or either form prefixed with
// "http://" or "https://" and/or followed by a single trailing slash.
// Surrounding whitespace is ignored. If the host is already an IPv4 or IPv6
// literal (optionally bracketed) it is returned without any lookup.
//
// Each entry of publicDNS is a DNS server address; entries without a port
// default to port 53. Servers are tried in order and the first successful
// lookup wins. Blank entries are skipped. When publicDNS contains no usable
// entry the system resolver is used.
//
// The first IPv4 address in the answer is preferred; if there is none, the
// first returned address is used. The result is "ip" when domain carried no
// port and net.JoinHostPort(ip, port) otherwise.
func ResolveDomainUpstream(domain string, publicDNS []string) (string, error) {
	// Normalize the input. The scheme prefix must be removed before splitting
	// host and port, otherwise the "//" confuses net.SplitHostPort.
	domain = strings.TrimSpace(domain)
	domain = strings.TrimPrefix(domain, "http://")
	domain = strings.TrimPrefix(domain, "https://")
	domain = strings.TrimSuffix(domain, "/")

	host, port, err := net.SplitHostPort(domain)
	if err != nil {
		// No port present (or not parseable as host:port): treat the whole
		// string as the host.
		host, port = domain, ""
	}

	// net.SplitHostPort strips IPv6 brackets, but when there was no port the
	// host may still be in the bracketed "[::1]" form.
	cleanHost := strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[")
	if ip := net.ParseIP(cleanHost); ip != nil {
		// Already an IP address, nothing to resolve.
		if port != "" {
			return net.JoinHostPort(ip.String(), port), nil
		}
		return ip.String(), nil
	}

	ips, err := lookupIPViaUpstream(host, publicDNS)
	if err != nil {
		return "", err
	}
	if len(ips) == 0 {
		return "", fmt.Errorf("no IP addresses found for domain %s", host)
	}

	// Default to the first address (possibly IPv6) and upgrade to the first
	// IPv4 address if the answer contains one.
	ipAddr := ips[0].String()
	for _, ip := range ips {
		if ipv4 := ip.To4(); ipv4 != nil {
			ipAddr = ipv4.String()
			break
		}
	}

	if port != "" {
		ipAddr = net.JoinHostPort(ipAddr, port)
	}
	return ipAddr, nil
}

// lookupIPViaUpstream looks up host on each usable server in publicDNS in
// order and returns the first successful answer. If publicDNS has no usable
// entry it falls back to the system resolver.
func lookupIPViaUpstream(host string, publicDNS []string) ([]net.IP, error) {
	var lastErr error
	tried := false
	for _, server := range publicDNS {
		dnsAddr := upstreamDNSAddr(server)
		if dnsAddr == "" {
			continue
		}
		tried = true

		resolver := &net.Resolver{
			PreferGo: true,
			// Redirect every resolver connection to the chosen upstream
			// server. The requested network is passed through unchanged so
			// that the resolver's retry over TCP (used when a UDP reply is
			// truncated) really happens over TCP.
			Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
				var d net.Dialer
				return d.DialContext(ctx, network, dnsAddr)
			},
		}
		ips, err := resolver.LookupIP(context.Background(), "ip", host)
		if err == nil {
			return ips, nil
		}
		lastErr = err
	}
	if tried {
		return nil, fmt.Errorf("DNS lookup failed using all upstream servers: %w", lastErr)
	}

	ips, err := net.LookupIP(host)
	if err != nil {
		return nil, fmt.Errorf("DNS lookup failed: %w", err)
	}
	return ips, nil
}

// upstreamDNSAddr normalizes one upstream DNS server entry to "host:port"
// form, defaulting the port to 53. It returns "" for a blank entry.
func upstreamDNSAddr(server string) string {
	server = strings.TrimSpace(server)
	if server == "" {
		return ""
	}
	if _, _, err := net.SplitHostPort(server); err == nil {
		return server
	}
	// No port given. Drop brackets from a bare "[v6]" literal so that
	// net.JoinHostPort does not bracket it a second time.
	server = strings.TrimSuffix(strings.TrimPrefix(server, "["), "]")
	return net.JoinHostPort(server, "53")
}