mirror of
https://github.com/fosrl/newt.git
synced 2026-03-29 05:56:38 +00:00
Add chainId based dedup
This commit is contained in:
@@ -2,6 +2,8 @@ package clients
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -34,6 +36,7 @@ type WgConfig struct {
|
||||
IpAddress string `json:"ipAddress"`
|
||||
Peers []Peer `json:"peers"`
|
||||
Targets []Target `json:"targets"`
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
|
||||
type Target struct {
|
||||
@@ -82,7 +85,8 @@ type WireGuardService struct {
|
||||
host string
|
||||
serverPubKey string
|
||||
token string
|
||||
stopGetConfig func()
|
||||
stopGetConfig func()
|
||||
pendingConfigChainId string
|
||||
// Netstack fields
|
||||
tun tun.Device
|
||||
tnet *netstack2.Net
|
||||
@@ -107,6 +111,13 @@ type WireGuardService struct {
|
||||
wgTesterServer *wgtester.Server
|
||||
}
|
||||
|
||||
// generateChainId generates a random chain ID for deduplicating round-trip messages.
|
||||
func generateChainId() string {
|
||||
b := make([]byte, 8)
|
||||
_, _ = rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
@@ -442,9 +453,12 @@ func (s *WireGuardService) LoadRemoteConfig() error {
|
||||
s.stopGetConfig()
|
||||
s.stopGetConfig = nil
|
||||
}
|
||||
chainId := generateChainId()
|
||||
s.pendingConfigChainId = chainId
|
||||
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
|
||||
"publicKey": s.key.PublicKey().String(),
|
||||
"port": s.Port,
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second)
|
||||
|
||||
logger.Debug("Requesting WireGuard configuration from remote server")
|
||||
@@ -469,6 +483,17 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Deduplicate using chainId: discard responses that don't match the
|
||||
// pending request, or that we have already processed.
|
||||
if config.ChainId != "" {
|
||||
if config.ChainId != s.pendingConfigChainId {
|
||||
logger.Debug("Discarding duplicate/stale newt/wg/get-config response (chainId=%s, expected=%s)", config.ChainId, s.pendingConfigChainId)
|
||||
return
|
||||
}
|
||||
s.pendingConfigChainId = "" // consume – further duplicates are rejected
|
||||
}
|
||||
|
||||
s.config = config
|
||||
|
||||
if s.stopGetConfig != nil {
|
||||
|
||||
@@ -287,9 +287,12 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
|
||||
}
|
||||
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
|
||||
// Send registration message to the server for backward compatibility
|
||||
bcChainId := generateChainId()
|
||||
pendingRegisterChainId = bcChainId
|
||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"backwardsCompatible": true,
|
||||
"chainId": bcChainId,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
|
||||
38
main.go
38
main.go
@@ -3,7 +3,9 @@ package main
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
@@ -46,6 +48,7 @@ type WgData struct {
|
||||
TunnelIP string `json:"tunnelIP"`
|
||||
Targets TargetsByType `json:"targets"`
|
||||
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
|
||||
type TargetsByType struct {
|
||||
@@ -128,6 +131,7 @@ var (
|
||||
publicKey wgtypes.Key
|
||||
pingStopChan chan struct{}
|
||||
stopFunc func()
|
||||
pendingRegisterChainId string
|
||||
healthFile string
|
||||
useNativeInterface bool
|
||||
authorizedKeysFile string
|
||||
@@ -161,6 +165,13 @@ var (
|
||||
tlsPrivateKey string
|
||||
)
|
||||
|
||||
// generateChainId generates a random chain ID for deduplicating round-trip messages.
|
||||
func generateChainId() string {
|
||||
b := make([]byte, 8)
|
||||
_, _ = rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Check for subcommands first (only principals exits early)
|
||||
if len(os.Args) > 1 {
|
||||
@@ -706,6 +717,24 @@ func runNewtMain(ctx context.Context) {
|
||||
defer func() {
|
||||
telemetry.IncSiteRegistration(ctx, regResult)
|
||||
}()
|
||||
|
||||
// Deduplicate using chainId: if the server echoes back a chainId we have
|
||||
// already consumed (or one that doesn't match our current pending request),
|
||||
// throw the message away to avoid setting up the tunnel twice.
|
||||
var chainData struct {
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
if jsonBytes, err := json.Marshal(msg.Data); err == nil {
|
||||
_ = json.Unmarshal(jsonBytes, &chainData)
|
||||
}
|
||||
if chainData.ChainId != "" {
|
||||
if chainData.ChainId != pendingRegisterChainId {
|
||||
logger.Debug("Discarding duplicate/stale newt/wg/connect (chainId=%s, expected=%s)", chainData.ChainId, pendingRegisterChainId)
|
||||
return
|
||||
}
|
||||
pendingRegisterChainId = "" // consume – further duplicates with this id are rejected
|
||||
}
|
||||
|
||||
if stopFunc != nil {
|
||||
stopFunc() // stop the ws from sending more requests
|
||||
stopFunc = nil // reset stopFunc to nil to avoid double stopping
|
||||
@@ -971,10 +1000,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
},
|
||||
}
|
||||
|
||||
chainId := generateChainId()
|
||||
pendingRegisterChainId = chainId
|
||||
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"pingResults": pingResults,
|
||||
"newtVersion": newtVersion,
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second)
|
||||
|
||||
return
|
||||
@@ -1074,10 +1106,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
}
|
||||
|
||||
// Send the ping results to the cloud for selection
|
||||
chainId := generateChainId()
|
||||
pendingRegisterChainId = chainId
|
||||
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"pingResults": pingResults,
|
||||
"newtVersion": newtVersion,
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second)
|
||||
|
||||
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
|
||||
@@ -1740,10 +1775,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
}
|
||||
|
||||
// Send registration message to the server for backward compatibility
|
||||
bcChainId := generateChainId()
|
||||
pendingRegisterChainId = bcChainId
|
||||
err := client.SendMessage(topicWGRegister, map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"newtVersion": newtVersion,
|
||||
"backwardsCompatible": true,
|
||||
"chainId": bcChainId,
|
||||
})
|
||||
|
||||
sendBlueprint(client)
|
||||
|
||||
Reference in New Issue
Block a user