diff --git a/clients/clients.go b/clients/clients.go index 4c64dbd..6130dcb 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -2,6 +2,8 @@ package clients import ( "context" + "crypto/rand" + "encoding/hex" "encoding/json" "fmt" "net" @@ -34,6 +36,7 @@ type WgConfig struct { IpAddress string `json:"ipAddress"` Peers []Peer `json:"peers"` Targets []Target `json:"targets"` + ChainId string `json:"chainId"` } type Target struct { @@ -82,7 +85,8 @@ type WireGuardService struct { host string serverPubKey string token string - stopGetConfig func() + stopGetConfig func() + pendingConfigChainId string // Netstack fields tun tun.Device tnet *netstack2.Net @@ -107,6 +111,13 @@ type WireGuardService struct { wgTesterServer *wgtester.Server } +// generateChainId generates a random chain ID for deduplicating round-trip messages. +func generateChainId() string { + b := make([]byte, 8) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { @@ -442,9 +453,12 @@ func (s *WireGuardService) LoadRemoteConfig() error { s.stopGetConfig() s.stopGetConfig = nil } + chainId := generateChainId() + s.pendingConfigChainId = chainId s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{ "publicKey": s.key.PublicKey().String(), "port": s.Port, + "chainId": chainId, }, 2*time.Second) logger.Debug("Requesting WireGuard configuration from remote server") @@ -469,6 +483,17 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { logger.Info("Error unmarshaling target data: %v", err) return } + + // Deduplicate using chainId: discard responses that don't match the + // pending request, or that we have already processed. + if config.ChainId != "" { + if config.ChainId != s.pendingConfigChainId { + logger.Debug("Discarding duplicate/stale newt/wg/get-config response (chainId=%s, expected=%s)", config.ChainId, s.pendingConfigChainId) + return + } + s.pendingConfigChainId = "" // consume – further duplicates are rejected + } + s.config = config if s.stopGetConfig != nil { diff --git a/common.go b/common.go index 4701411..707eefa 100644 --- a/common.go +++ b/common.go @@ -287,9 +287,12 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien } stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) // Send registration message to the server for backward compatibility + bcChainId := generateChainId() + pendingRegisterChainId = bcChainId err := client.SendMessage("newt/wg/register", map[string]interface{}{ "publicKey": publicKey.String(), "backwardsCompatible": true, + "chainId": bcChainId, }) if err != nil { logger.Error("Failed to send registration message: %v", err) diff --git a/main.go b/main.go index 3646a27..a79c70d 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,9 @@ package main import ( "bytes" "context" + "crypto/rand" "crypto/tls" + "encoding/hex" "encoding/json" "errors" "flag" @@ -46,6 +48,7 @@ type WgData struct { TunnelIP string `json:"tunnelIP"` Targets TargetsByType `json:"targets"` HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"` + ChainId string `json:"chainId"` } type TargetsByType struct { @@ -128,6 +131,7 @@ var ( publicKey wgtypes.Key pingStopChan chan struct{} stopFunc func() + pendingRegisterChainId string healthFile string useNativeInterface bool authorizedKeysFile string @@ -161,6 +165,13 @@ var ( tlsPrivateKey string ) +// generateChainId generates a random chain ID for deduplicating round-trip messages. +func generateChainId() string { + b := make([]byte, 8) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + func main() { // Check for subcommands first (only principals exits early) if len(os.Args) > 1 { @@ -706,6 +717,24 @@ func runNewtMain(ctx context.Context) { defer func() { telemetry.IncSiteRegistration(ctx, regResult) }() + + // Deduplicate using chainId: if the server echoes back a chainId we have + // already consumed (or one that doesn't match our current pending request), + // throw the message away to avoid setting up the tunnel twice. + var chainData struct { + ChainId string `json:"chainId"` + } + if jsonBytes, err := json.Marshal(msg.Data); err == nil { + _ = json.Unmarshal(jsonBytes, &chainData) + } + if chainData.ChainId != "" { + if chainData.ChainId != pendingRegisterChainId { + logger.Debug("Discarding duplicate/stale newt/wg/connect (chainId=%s, expected=%s)", chainData.ChainId, pendingRegisterChainId) + return + } + pendingRegisterChainId = "" // consume – further duplicates with this id are rejected + } + if stopFunc != nil { stopFunc() // stop the ws from sending more requests stopFunc = nil // reset stopFunc to nil to avoid double stopping @@ -971,10 +1000,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( }, } + chainId := generateChainId() + pendingRegisterChainId = chainId stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{ "publicKey": publicKey.String(), "pingResults": pingResults, "newtVersion": newtVersion, + "chainId": chainId, }, 2*time.Second) return @@ -1074,10 +1106,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( } // Send the ping results to the cloud for selection + chainId := generateChainId() + pendingRegisterChainId = chainId stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{ "publicKey": publicKey.String(), "pingResults": pingResults, "newtVersion": newtVersion, + "chainId": chainId, }, 2*time.Second) logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults) @@ -1740,10 +1775,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( } // Send registration message to the server for backward compatibility + bcChainId := generateChainId() + pendingRegisterChainId = bcChainId err := client.SendMessage(topicWGRegister, map[string]interface{}{ "publicKey": publicKey.String(), "newtVersion": newtVersion, "backwardsCompatible": true, + "chainId": bcChainId, }) sendBlueprint(client)