mirror of
https://github.com/fosrl/newt.git
synced 2026-04-02 07:56:39 +00:00
Add chainId based dedup
This commit is contained in:
@@ -2,6 +2,8 @@ package clients
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@@ -34,6 +36,7 @@ type WgConfig struct {
|
|||||||
IpAddress string `json:"ipAddress"`
|
IpAddress string `json:"ipAddress"`
|
||||||
Peers []Peer `json:"peers"`
|
Peers []Peer `json:"peers"`
|
||||||
Targets []Target `json:"targets"`
|
Targets []Target `json:"targets"`
|
||||||
|
ChainId string `json:"chainId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Target struct {
|
type Target struct {
|
||||||
@@ -82,7 +85,8 @@ type WireGuardService struct {
|
|||||||
host string
|
host string
|
||||||
serverPubKey string
|
serverPubKey string
|
||||||
token string
|
token string
|
||||||
stopGetConfig func()
|
stopGetConfig func()
|
||||||
|
pendingConfigChainId string
|
||||||
// Netstack fields
|
// Netstack fields
|
||||||
tun tun.Device
|
tun tun.Device
|
||||||
tnet *netstack2.Net
|
tnet *netstack2.Net
|
||||||
@@ -107,6 +111,13 @@ type WireGuardService struct {
|
|||||||
wgTesterServer *wgtester.Server
|
wgTesterServer *wgtester.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generateChainId generates a random chain ID for deduplicating round-trip messages.
|
||||||
|
func generateChainId() string {
|
||||||
|
b := make([]byte, 8)
|
||||||
|
_, _ = rand.Read(b)
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) {
|
func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) {
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -442,9 +453,12 @@ func (s *WireGuardService) LoadRemoteConfig() error {
|
|||||||
s.stopGetConfig()
|
s.stopGetConfig()
|
||||||
s.stopGetConfig = nil
|
s.stopGetConfig = nil
|
||||||
}
|
}
|
||||||
|
chainId := generateChainId()
|
||||||
|
s.pendingConfigChainId = chainId
|
||||||
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
|
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
|
||||||
"publicKey": s.key.PublicKey().String(),
|
"publicKey": s.key.PublicKey().String(),
|
||||||
"port": s.Port,
|
"port": s.Port,
|
||||||
|
"chainId": chainId,
|
||||||
}, 2*time.Second)
|
}, 2*time.Second)
|
||||||
|
|
||||||
logger.Debug("Requesting WireGuard configuration from remote server")
|
logger.Debug("Requesting WireGuard configuration from remote server")
|
||||||
@@ -469,6 +483,17 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
|||||||
logger.Info("Error unmarshaling target data: %v", err)
|
logger.Info("Error unmarshaling target data: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deduplicate using chainId: discard responses that don't match the
|
||||||
|
// pending request, or that we have already processed.
|
||||||
|
if config.ChainId != "" {
|
||||||
|
if config.ChainId != s.pendingConfigChainId {
|
||||||
|
logger.Debug("Discarding duplicate/stale newt/wg/get-config response (chainId=%s, expected=%s)", config.ChainId, s.pendingConfigChainId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.pendingConfigChainId = "" // consume – further duplicates are rejected
|
||||||
|
}
|
||||||
|
|
||||||
s.config = config
|
s.config = config
|
||||||
|
|
||||||
if s.stopGetConfig != nil {
|
if s.stopGetConfig != nil {
|
||||||
|
|||||||
@@ -287,9 +287,12 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
|
|||||||
}
|
}
|
||||||
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
|
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
|
||||||
// Send registration message to the server for backward compatibility
|
// Send registration message to the server for backward compatibility
|
||||||
|
bcChainId := generateChainId()
|
||||||
|
pendingRegisterChainId = bcChainId
|
||||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||||
"publicKey": publicKey.String(),
|
"publicKey": publicKey.String(),
|
||||||
"backwardsCompatible": true,
|
"backwardsCompatible": true,
|
||||||
|
"chainId": bcChainId,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to send registration message: %v", err)
|
logger.Error("Failed to send registration message: %v", err)
|
||||||
|
|||||||
38
main.go
38
main.go
@@ -3,7 +3,9 @@ package main
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
@@ -46,6 +48,7 @@ type WgData struct {
|
|||||||
TunnelIP string `json:"tunnelIP"`
|
TunnelIP string `json:"tunnelIP"`
|
||||||
Targets TargetsByType `json:"targets"`
|
Targets TargetsByType `json:"targets"`
|
||||||
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
|
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
|
||||||
|
ChainId string `json:"chainId"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TargetsByType struct {
|
type TargetsByType struct {
|
||||||
@@ -128,6 +131,7 @@ var (
|
|||||||
publicKey wgtypes.Key
|
publicKey wgtypes.Key
|
||||||
pingStopChan chan struct{}
|
pingStopChan chan struct{}
|
||||||
stopFunc func()
|
stopFunc func()
|
||||||
|
pendingRegisterChainId string
|
||||||
healthFile string
|
healthFile string
|
||||||
useNativeInterface bool
|
useNativeInterface bool
|
||||||
authorizedKeysFile string
|
authorizedKeysFile string
|
||||||
@@ -161,6 +165,13 @@ var (
|
|||||||
tlsPrivateKey string
|
tlsPrivateKey string
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// generateChainId generates a random chain ID for deduplicating round-trip messages.
|
||||||
|
func generateChainId() string {
|
||||||
|
b := make([]byte, 8)
|
||||||
|
_, _ = rand.Read(b)
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Check for subcommands first (only principals exits early)
|
// Check for subcommands first (only principals exits early)
|
||||||
if len(os.Args) > 1 {
|
if len(os.Args) > 1 {
|
||||||
@@ -706,6 +717,24 @@ func runNewtMain(ctx context.Context) {
|
|||||||
defer func() {
|
defer func() {
|
||||||
telemetry.IncSiteRegistration(ctx, regResult)
|
telemetry.IncSiteRegistration(ctx, regResult)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Deduplicate using chainId: if the server echoes back a chainId we have
|
||||||
|
// already consumed (or one that doesn't match our current pending request),
|
||||||
|
// throw the message away to avoid setting up the tunnel twice.
|
||||||
|
var chainData struct {
|
||||||
|
ChainId string `json:"chainId"`
|
||||||
|
}
|
||||||
|
if jsonBytes, err := json.Marshal(msg.Data); err == nil {
|
||||||
|
_ = json.Unmarshal(jsonBytes, &chainData)
|
||||||
|
}
|
||||||
|
if chainData.ChainId != "" {
|
||||||
|
if chainData.ChainId != pendingRegisterChainId {
|
||||||
|
logger.Debug("Discarding duplicate/stale newt/wg/connect (chainId=%s, expected=%s)", chainData.ChainId, pendingRegisterChainId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pendingRegisterChainId = "" // consume – further duplicates with this id are rejected
|
||||||
|
}
|
||||||
|
|
||||||
if stopFunc != nil {
|
if stopFunc != nil {
|
||||||
stopFunc() // stop the ws from sending more requests
|
stopFunc() // stop the ws from sending more requests
|
||||||
stopFunc = nil // reset stopFunc to nil to avoid double stopping
|
stopFunc = nil // reset stopFunc to nil to avoid double stopping
|
||||||
@@ -971,10 +1000,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
chainId := generateChainId()
|
||||||
|
pendingRegisterChainId = chainId
|
||||||
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
|
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
|
||||||
"publicKey": publicKey.String(),
|
"publicKey": publicKey.String(),
|
||||||
"pingResults": pingResults,
|
"pingResults": pingResults,
|
||||||
"newtVersion": newtVersion,
|
"newtVersion": newtVersion,
|
||||||
|
"chainId": chainId,
|
||||||
}, 2*time.Second)
|
}, 2*time.Second)
|
||||||
|
|
||||||
return
|
return
|
||||||
@@ -1074,10 +1106,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send the ping results to the cloud for selection
|
// Send the ping results to the cloud for selection
|
||||||
|
chainId := generateChainId()
|
||||||
|
pendingRegisterChainId = chainId
|
||||||
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
|
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
|
||||||
"publicKey": publicKey.String(),
|
"publicKey": publicKey.String(),
|
||||||
"pingResults": pingResults,
|
"pingResults": pingResults,
|
||||||
"newtVersion": newtVersion,
|
"newtVersion": newtVersion,
|
||||||
|
"chainId": chainId,
|
||||||
}, 2*time.Second)
|
}, 2*time.Second)
|
||||||
|
|
||||||
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
|
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
|
||||||
@@ -1740,10 +1775,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send registration message to the server for backward compatibility
|
// Send registration message to the server for backward compatibility
|
||||||
|
bcChainId := generateChainId()
|
||||||
|
pendingRegisterChainId = bcChainId
|
||||||
err := client.SendMessage(topicWGRegister, map[string]interface{}{
|
err := client.SendMessage(topicWGRegister, map[string]interface{}{
|
||||||
"publicKey": publicKey.String(),
|
"publicKey": publicKey.String(),
|
||||||
"newtVersion": newtVersion,
|
"newtVersion": newtVersion,
|
||||||
"backwardsCompatible": true,
|
"backwardsCompatible": true,
|
||||||
|
"chainId": bcChainId,
|
||||||
})
|
})
|
||||||
|
|
||||||
sendBlueprint(client)
|
sendBlueprint(client)
|
||||||
|
|||||||
Reference in New Issue
Block a user