Add chainId based dedup

This commit is contained in:
Owen
2026-03-27 11:55:34 -07:00
parent a2683eb385
commit 1057013b50
3 changed files with 67 additions and 1 deletions

View File

@@ -2,6 +2,8 @@ package clients
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
@@ -34,6 +36,7 @@ type WgConfig struct {
IpAddress string `json:"ipAddress"` IpAddress string `json:"ipAddress"`
Peers []Peer `json:"peers"` Peers []Peer `json:"peers"`
Targets []Target `json:"targets"` Targets []Target `json:"targets"`
ChainId string `json:"chainId"`
} }
type Target struct { type Target struct {
@@ -83,6 +86,7 @@ type WireGuardService struct {
serverPubKey string serverPubKey string
token string token string
stopGetConfig func() stopGetConfig func()
pendingConfigChainId string
// Netstack fields // Netstack fields
tun tun.Device tun tun.Device
tnet *netstack2.Net tnet *netstack2.Net
@@ -107,6 +111,13 @@ type WireGuardService struct {
wgTesterServer *wgtester.Server wgTesterServer *wgtester.Server
} }
// generateChainId generates a random chain ID for deduplicating round-trip messages.
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) { func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
@@ -442,9 +453,12 @@ func (s *WireGuardService) LoadRemoteConfig() error {
s.stopGetConfig() s.stopGetConfig()
s.stopGetConfig = nil s.stopGetConfig = nil
} }
chainId := generateChainId()
s.pendingConfigChainId = chainId
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{ s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
"publicKey": s.key.PublicKey().String(), "publicKey": s.key.PublicKey().String(),
"port": s.Port, "port": s.Port,
"chainId": chainId,
}, 2*time.Second) }, 2*time.Second)
logger.Debug("Requesting WireGuard configuration from remote server") logger.Debug("Requesting WireGuard configuration from remote server")
@@ -469,6 +483,17 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
logger.Info("Error unmarshaling target data: %v", err) logger.Info("Error unmarshaling target data: %v", err)
return return
} }
// Deduplicate using chainId: discard responses that don't match the
// pending request, or that we have already processed.
if config.ChainId != "" {
if config.ChainId != s.pendingConfigChainId {
logger.Debug("Discarding duplicate/stale newt/wg/get-config response (chainId=%s, expected=%s)", config.ChainId, s.pendingConfigChainId)
return
}
s.pendingConfigChainId = "" // consume further duplicates are rejected
}
s.config = config s.config = config
if s.stopGetConfig != nil { if s.stopGetConfig != nil {

View File

@@ -287,9 +287,12 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
} }
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
// Send registration message to the server for backward compatibility // Send registration message to the server for backward compatibility
bcChainId := generateChainId()
pendingRegisterChainId = bcChainId
err := client.SendMessage("newt/wg/register", map[string]interface{}{ err := client.SendMessage("newt/wg/register", map[string]interface{}{
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"backwardsCompatible": true, "backwardsCompatible": true,
"chainId": bcChainId,
}) })
if err != nil { if err != nil {
logger.Error("Failed to send registration message: %v", err) logger.Error("Failed to send registration message: %v", err)

38
main.go
View File

@@ -3,7 +3,9 @@ package main
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/rand"
"crypto/tls" "crypto/tls"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"flag" "flag"
@@ -46,6 +48,7 @@ type WgData struct {
TunnelIP string `json:"tunnelIP"` TunnelIP string `json:"tunnelIP"`
Targets TargetsByType `json:"targets"` Targets TargetsByType `json:"targets"`
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"` HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
ChainId string `json:"chainId"`
} }
type TargetsByType struct { type TargetsByType struct {
@@ -128,6 +131,7 @@ var (
publicKey wgtypes.Key publicKey wgtypes.Key
pingStopChan chan struct{} pingStopChan chan struct{}
stopFunc func() stopFunc func()
pendingRegisterChainId string
healthFile string healthFile string
useNativeInterface bool useNativeInterface bool
authorizedKeysFile string authorizedKeysFile string
@@ -161,6 +165,13 @@ var (
tlsPrivateKey string tlsPrivateKey string
) )
// generateChainId generates a random chain ID for deduplicating round-trip messages.
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func main() { func main() {
// Check for subcommands first (only principals exits early) // Check for subcommands first (only principals exits early)
if len(os.Args) > 1 { if len(os.Args) > 1 {
@@ -706,6 +717,24 @@ func runNewtMain(ctx context.Context) {
defer func() { defer func() {
telemetry.IncSiteRegistration(ctx, regResult) telemetry.IncSiteRegistration(ctx, regResult)
}() }()
// Deduplicate using chainId: if the server echoes back a chainId we have
// already consumed (or one that doesn't match our current pending request),
// throw the message away to avoid setting up the tunnel twice.
var chainData struct {
ChainId string `json:"chainId"`
}
if jsonBytes, err := json.Marshal(msg.Data); err == nil {
_ = json.Unmarshal(jsonBytes, &chainData)
}
if chainData.ChainId != "" {
if chainData.ChainId != pendingRegisterChainId {
logger.Debug("Discarding duplicate/stale newt/wg/connect (chainId=%s, expected=%s)", chainData.ChainId, pendingRegisterChainId)
return
}
pendingRegisterChainId = "" // consume further duplicates with this id are rejected
}
if stopFunc != nil { if stopFunc != nil {
stopFunc() // stop the ws from sending more requests stopFunc() // stop the ws from sending more requests
stopFunc = nil // reset stopFunc to nil to avoid double stopping stopFunc = nil // reset stopFunc to nil to avoid double stopping
@@ -971,10 +1000,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
}, },
} }
chainId := generateChainId()
pendingRegisterChainId = chainId
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{ stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"pingResults": pingResults, "pingResults": pingResults,
"newtVersion": newtVersion, "newtVersion": newtVersion,
"chainId": chainId,
}, 2*time.Second) }, 2*time.Second)
return return
@@ -1074,10 +1106,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
} }
// Send the ping results to the cloud for selection // Send the ping results to the cloud for selection
chainId := generateChainId()
pendingRegisterChainId = chainId
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{ stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"pingResults": pingResults, "pingResults": pingResults,
"newtVersion": newtVersion, "newtVersion": newtVersion,
"chainId": chainId,
}, 2*time.Second) }, 2*time.Second)
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults) logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
@@ -1740,10 +1775,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
} }
// Send registration message to the server for backward compatibility // Send registration message to the server for backward compatibility
bcChainId := generateChainId()
pendingRegisterChainId = bcChainId
err := client.SendMessage(topicWGRegister, map[string]interface{}{ err := client.SendMessage(topicWGRegister, map[string]interface{}{
"publicKey": publicKey.String(), "publicKey": publicKey.String(),
"newtVersion": newtVersion, "newtVersion": newtVersion,
"backwardsCompatible": true, "backwardsCompatible": true,
"chainId": bcChainId,
}) })
sendBlueprint(client) sendBlueprint(client)