mirror of
https://github.com/fosrl/newt.git
synced 2026-03-30 06:26:41 +00:00
Merge branch 'main' into logging-provision
This commit is contained in:
@@ -2,6 +2,8 @@ package clients
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -34,6 +36,7 @@ type WgConfig struct {
|
||||
IpAddress string `json:"ipAddress"`
|
||||
Peers []Peer `json:"peers"`
|
||||
Targets []Target `json:"targets"`
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
|
||||
type Target struct {
|
||||
@@ -83,7 +86,8 @@ type WireGuardService struct {
|
||||
host string
|
||||
serverPubKey string
|
||||
token string
|
||||
stopGetConfig func()
|
||||
stopGetConfig func()
|
||||
pendingConfigChainId string
|
||||
// Netstack fields
|
||||
tun tun.Device
|
||||
tnet *netstack2.Net
|
||||
@@ -108,6 +112,13 @@ type WireGuardService struct {
|
||||
wgTesterServer *wgtester.Server
|
||||
}
|
||||
|
||||
// generateChainId generates a random chain ID for deduplicating round-trip messages.
|
||||
func generateChainId() string {
|
||||
b := make([]byte, 8)
|
||||
_, _ = rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func NewWireGuardService(interfaceName string, port uint16, mtu int, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
@@ -162,9 +173,8 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
|
||||
useNativeInterface: useNativeInterface,
|
||||
}
|
||||
|
||||
// Create the holepunch manager with ResolveDomain function
|
||||
// We'll need to pass a domain resolver function
|
||||
service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String())
|
||||
// Create the holepunch manager
|
||||
service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String(), nil)
|
||||
|
||||
// Register websocket handlers
|
||||
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
|
||||
@@ -452,9 +462,12 @@ func (s *WireGuardService) LoadRemoteConfig() error {
|
||||
s.stopGetConfig()
|
||||
s.stopGetConfig = nil
|
||||
}
|
||||
chainId := generateChainId()
|
||||
s.pendingConfigChainId = chainId
|
||||
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
|
||||
"publicKey": s.key.PublicKey().String(),
|
||||
"port": s.Port,
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second)
|
||||
|
||||
logger.Debug("Requesting WireGuard configuration from remote server")
|
||||
@@ -479,6 +492,17 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Deduplicate using chainId: discard responses that don't match the
|
||||
// pending request, or that we have already processed.
|
||||
if config.ChainId != "" {
|
||||
if config.ChainId != s.pendingConfigChainId {
|
||||
logger.Debug("Discarding duplicate/stale newt/wg/get-config response (chainId=%s, expected=%s)", config.ChainId, s.pendingConfigChainId)
|
||||
return
|
||||
}
|
||||
s.pendingConfigChainId = "" // consume – further duplicates are rejected
|
||||
}
|
||||
|
||||
s.config = config
|
||||
|
||||
if s.stopGetConfig != nil {
|
||||
|
||||
@@ -286,11 +286,18 @@ func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Clien
|
||||
if tunnelID != "" {
|
||||
telemetry.IncReconnect(context.Background(), tunnelID, "client", telemetry.ReasonTimeout)
|
||||
}
|
||||
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
|
||||
pingChainId := generateChainId()
|
||||
pendingPingChainId = pingChainId
|
||||
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
|
||||
"chainId": pingChainId,
|
||||
}, 3*time.Second)
|
||||
// Send registration message to the server for backward compatibility
|
||||
bcChainId := generateChainId()
|
||||
pendingRegisterChainId = bcChainId
|
||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"backwardsCompatible": true,
|
||||
"chainId": bcChainId,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
|
||||
@@ -27,16 +27,17 @@ type ExitNode struct {
|
||||
|
||||
// Manager handles UDP hole punching operations
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
sharedBind *bind.SharedBind
|
||||
ID string
|
||||
token string
|
||||
publicKey string
|
||||
clientType string
|
||||
exitNodes map[string]ExitNode // key is endpoint
|
||||
updateChan chan struct{} // signals the goroutine to refresh exit nodes
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
sharedBind *bind.SharedBind
|
||||
ID string
|
||||
token string
|
||||
publicKey string
|
||||
clientType string
|
||||
exitNodes map[string]ExitNode // key is endpoint
|
||||
updateChan chan struct{} // signals the goroutine to refresh exit nodes
|
||||
publicDNS []string
|
||||
|
||||
sendHolepunchInterval time.Duration
|
||||
sendHolepunchIntervalMin time.Duration
|
||||
@@ -49,12 +50,13 @@ const defaultSendHolepunchIntervalMax = 60 * time.Second
|
||||
const defaultSendHolepunchIntervalMin = 1 * time.Second
|
||||
|
||||
// NewManager creates a new hole punch manager
|
||||
func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string) *Manager {
|
||||
func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string, publicDNS []string) *Manager {
|
||||
return &Manager{
|
||||
sharedBind: sharedBind,
|
||||
ID: ID,
|
||||
clientType: clientType,
|
||||
publicKey: publicKey,
|
||||
publicDNS: publicDNS,
|
||||
exitNodes: make(map[string]ExitNode),
|
||||
sendHolepunchInterval: defaultSendHolepunchIntervalMin,
|
||||
sendHolepunchIntervalMin: defaultSendHolepunchIntervalMin,
|
||||
@@ -281,7 +283,13 @@ func (m *Manager) TriggerHolePunch() error {
|
||||
// Send hole punch to all exit nodes
|
||||
successCount := 0
|
||||
for _, exitNode := range currentExitNodes {
|
||||
host, err := util.ResolveDomain(exitNode.Endpoint)
|
||||
var host string
|
||||
var err error
|
||||
if len(m.publicDNS) > 0 {
|
||||
host, err = util.ResolveDomainUpstream(exitNode.Endpoint, m.publicDNS)
|
||||
} else {
|
||||
host, err = util.ResolveDomain(exitNode.Endpoint)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||
continue
|
||||
@@ -392,7 +400,13 @@ func (m *Manager) runMultipleExitNodes() {
|
||||
|
||||
var resolvedNodes []resolvedExitNode
|
||||
for _, exitNode := range currentExitNodes {
|
||||
host, err := util.ResolveDomain(exitNode.Endpoint)
|
||||
var host string
|
||||
var err error
|
||||
if len(m.publicDNS) > 0 {
|
||||
host, err = util.ResolveDomainUpstream(exitNode.Endpoint, m.publicDNS)
|
||||
} else {
|
||||
host, err = util.ResolveDomain(exitNode.Endpoint)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||
continue
|
||||
|
||||
@@ -49,10 +49,11 @@ type cachedAddr struct {
|
||||
|
||||
// HolepunchTester monitors holepunch connectivity using magic packets
|
||||
type HolepunchTester struct {
|
||||
sharedBind *bind.SharedBind
|
||||
mu sync.RWMutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
sharedBind *bind.SharedBind
|
||||
publicDNS []string
|
||||
mu sync.RWMutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
|
||||
// Pending requests waiting for responses (key: echo data as string)
|
||||
pendingRequests sync.Map // map[string]*pendingRequest
|
||||
@@ -84,9 +85,10 @@ type pendingRequest struct {
|
||||
}
|
||||
|
||||
// NewHolepunchTester creates a new holepunch tester using the given SharedBind
|
||||
func NewHolepunchTester(sharedBind *bind.SharedBind) *HolepunchTester {
|
||||
func NewHolepunchTester(sharedBind *bind.SharedBind, publicDNS []string) *HolepunchTester {
|
||||
return &HolepunchTester{
|
||||
sharedBind: sharedBind,
|
||||
publicDNS: publicDNS,
|
||||
addrCache: make(map[string]*cachedAddr),
|
||||
addrCacheTTL: 5 * time.Minute, // Cache addresses for 5 minutes
|
||||
}
|
||||
@@ -169,7 +171,13 @@ func (t *HolepunchTester) resolveEndpoint(endpoint string) (*net.UDPAddr, error)
|
||||
}
|
||||
|
||||
// Resolve the endpoint
|
||||
host, err := util.ResolveDomain(endpoint)
|
||||
var host string
|
||||
var err error
|
||||
if len(t.publicDNS) > 0 {
|
||||
host, err = util.ResolveDomainUpstream(endpoint, t.publicDNS)
|
||||
} else {
|
||||
host, err = util.ResolveDomain(endpoint)
|
||||
}
|
||||
if err != nil {
|
||||
host = endpoint
|
||||
}
|
||||
|
||||
55
main.go
55
main.go
@@ -3,7 +3,9 @@ package main
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
@@ -46,6 +48,7 @@ type WgData struct {
|
||||
TunnelIP string `json:"tunnelIP"`
|
||||
Targets TargetsByType `json:"targets"`
|
||||
HealthCheckTargets []healthcheck.Config `json:"healthCheckTargets"`
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
|
||||
type TargetsByType struct {
|
||||
@@ -59,6 +62,7 @@ type TargetData struct {
|
||||
|
||||
type ExitNodeData struct {
|
||||
ExitNodes []ExitNode `json:"exitNodes"`
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
|
||||
// ExitNode represents an exit node with an ID, endpoint, and weight.
|
||||
@@ -128,6 +132,8 @@ var (
|
||||
publicKey wgtypes.Key
|
||||
pingStopChan chan struct{}
|
||||
stopFunc func()
|
||||
pendingRegisterChainId string
|
||||
pendingPingChainId string
|
||||
healthFile string
|
||||
useNativeInterface bool
|
||||
authorizedKeysFile string
|
||||
@@ -167,6 +173,13 @@ var (
|
||||
configFile string
|
||||
)
|
||||
|
||||
// generateChainId generates a random chain ID for deduplicating round-trip messages.
|
||||
func generateChainId() string {
|
||||
b := make([]byte, 8)
|
||||
_, _ = rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Check for subcommands first (only principals exits early)
|
||||
if len(os.Args) > 1 {
|
||||
@@ -727,6 +740,24 @@ func runNewtMain(ctx context.Context) {
|
||||
defer func() {
|
||||
telemetry.IncSiteRegistration(ctx, regResult)
|
||||
}()
|
||||
|
||||
// Deduplicate using chainId: if the server echoes back a chainId we have
|
||||
// already consumed (or one that doesn't match our current pending request),
|
||||
// throw the message away to avoid setting up the tunnel twice.
|
||||
var chainData struct {
|
||||
ChainId string `json:"chainId"`
|
||||
}
|
||||
if jsonBytes, err := json.Marshal(msg.Data); err == nil {
|
||||
_ = json.Unmarshal(jsonBytes, &chainData)
|
||||
}
|
||||
if chainData.ChainId != "" {
|
||||
if chainData.ChainId != pendingRegisterChainId {
|
||||
logger.Debug("Discarding duplicate/stale newt/wg/connect (chainId=%s, expected=%s)", chainData.ChainId, pendingRegisterChainId)
|
||||
return
|
||||
}
|
||||
pendingRegisterChainId = "" // consume – further duplicates with this id are rejected
|
||||
}
|
||||
|
||||
if stopFunc != nil {
|
||||
stopFunc() // stop the ws from sending more requests
|
||||
stopFunc = nil // reset stopFunc to nil to avoid double stopping
|
||||
@@ -911,8 +942,11 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
}
|
||||
|
||||
// Request exit nodes from the server
|
||||
pingChainId := generateChainId()
|
||||
pendingPingChainId = pingChainId
|
||||
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
|
||||
"noCloud": noCloud,
|
||||
"chainId": pingChainId,
|
||||
}, 3*time.Second)
|
||||
|
||||
logger.Info("Tunnel destroyed, ready for reconnection")
|
||||
@@ -941,6 +975,7 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
|
||||
client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received ping message")
|
||||
|
||||
if stopFunc != nil {
|
||||
stopFunc() // stop the ws from sending more requests
|
||||
stopFunc = nil // reset stopFunc to nil to avoid double stopping
|
||||
@@ -960,6 +995,14 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
}
|
||||
exitNodes := exitNodeData.ExitNodes
|
||||
|
||||
if exitNodeData.ChainId != "" {
|
||||
if exitNodeData.ChainId != pendingPingChainId {
|
||||
logger.Debug("Discarding duplicate/stale newt/ping/exitNodes (chainId=%s, expected=%s)", exitNodeData.ChainId, pendingPingChainId)
|
||||
return
|
||||
}
|
||||
pendingPingChainId = "" // consume – further duplicates with this id are rejected
|
||||
}
|
||||
|
||||
if len(exitNodes) == 0 {
|
||||
logger.Info("No exit nodes provided")
|
||||
return
|
||||
@@ -992,10 +1035,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
},
|
||||
}
|
||||
|
||||
chainId := generateChainId()
|
||||
pendingRegisterChainId = chainId
|
||||
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"pingResults": pingResults,
|
||||
"newtVersion": newtVersion,
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second)
|
||||
|
||||
return
|
||||
@@ -1095,10 +1141,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
}
|
||||
|
||||
// Send the ping results to the cloud for selection
|
||||
chainId := generateChainId()
|
||||
pendingRegisterChainId = chainId
|
||||
stopFunc = client.SendMessageInterval(topicWGRegister, map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"pingResults": pingResults,
|
||||
"newtVersion": newtVersion,
|
||||
"chainId": chainId,
|
||||
}, 2*time.Second)
|
||||
|
||||
logger.Debug("Sent exit node ping results to cloud for selection: pingResults=%+v", pingResults)
|
||||
@@ -1748,8 +1797,11 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
stopFunc()
|
||||
}
|
||||
// request from the server the list of nodes to ping
|
||||
pingChainId := generateChainId()
|
||||
pendingPingChainId = pingChainId
|
||||
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{
|
||||
"noCloud": noCloud,
|
||||
"chainId": pingChainId,
|
||||
}, 3*time.Second)
|
||||
logger.Debug("Requesting exit nodes from server")
|
||||
|
||||
@@ -1761,10 +1813,13 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
|
||||
}
|
||||
|
||||
// Send registration message to the server for backward compatibility
|
||||
bcChainId := generateChainId()
|
||||
pendingRegisterChainId = bcChainId
|
||||
err := client.SendMessage(topicWGRegister, map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"newtVersion": newtVersion,
|
||||
"backwardsCompatible": true,
|
||||
"chainId": bcChainId,
|
||||
})
|
||||
|
||||
sendBlueprint(client)
|
||||
|
||||
94
util/util.go
94
util/util.go
@@ -1,6 +1,7 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
@@ -14,6 +15,99 @@ import (
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
)
|
||||
|
||||
func ResolveDomainUpstream(domain string, publicDNS []string) (string, error) {
|
||||
// trim whitespace
|
||||
domain = strings.TrimSpace(domain)
|
||||
|
||||
// Remove any protocol prefix if present (do this first, before splitting host/port)
|
||||
domain = strings.TrimPrefix(domain, "http://")
|
||||
domain = strings.TrimPrefix(domain, "https://")
|
||||
|
||||
// if there are any trailing slashes, remove them
|
||||
domain = strings.TrimSuffix(domain, "/")
|
||||
|
||||
// Check if there's a port in the domain
|
||||
host, port, err := net.SplitHostPort(domain)
|
||||
if err != nil {
|
||||
// No port found, use the domain as is
|
||||
host = domain
|
||||
port = ""
|
||||
}
|
||||
|
||||
// Check if host is already an IP address (IPv4 or IPv6)
|
||||
// For IPv6, the host from SplitHostPort will already have brackets stripped
|
||||
// but if there was no port, we need to handle bracketed IPv6 addresses
|
||||
cleanHost := strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[")
|
||||
if ip := net.ParseIP(cleanHost); ip != nil {
|
||||
// It's already an IP address, no need to resolve
|
||||
ipAddr := ip.String()
|
||||
if port != "" {
|
||||
return net.JoinHostPort(ipAddr, port), nil
|
||||
}
|
||||
return ipAddr, nil
|
||||
}
|
||||
|
||||
// Lookup IP addresses using the upstream DNS servers if provided
|
||||
var ips []net.IP
|
||||
if len(publicDNS) > 0 {
|
||||
var lastErr error
|
||||
for _, server := range publicDNS {
|
||||
// Ensure the upstream DNS address has a port
|
||||
dnsAddr := server
|
||||
if _, _, err := net.SplitHostPort(dnsAddr); err != nil {
|
||||
// No port specified, default to 53
|
||||
dnsAddr = net.JoinHostPort(server, "53")
|
||||
}
|
||||
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{}
|
||||
return d.DialContext(ctx, "udp", dnsAddr)
|
||||
},
|
||||
}
|
||||
ips, lastErr = resolver.LookupIP(context.Background(), "ip", host)
|
||||
if lastErr == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if lastErr != nil {
|
||||
return "", fmt.Errorf("DNS lookup failed using all upstream servers: %v", lastErr)
|
||||
}
|
||||
} else {
|
||||
ips, err = net.LookupIP(host)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("DNS lookup failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return "", fmt.Errorf("no IP addresses found for domain %s", host)
|
||||
}
|
||||
|
||||
// Get the first IPv4 address if available
|
||||
var ipAddr string
|
||||
for _, ip := range ips {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
ipAddr = ipv4.String()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If no IPv4 found, use the first IP (might be IPv6)
|
||||
if ipAddr == "" {
|
||||
ipAddr = ips[0].String()
|
||||
}
|
||||
|
||||
// Add port back if it existed
|
||||
if port != "" {
|
||||
ipAddr = net.JoinHostPort(ipAddr, port)
|
||||
}
|
||||
|
||||
return ipAddr, nil
|
||||
}
|
||||
|
||||
|
||||
func ResolveDomain(domain string) (string, error) {
|
||||
// trim whitespace
|
||||
domain = strings.TrimSpace(domain)
|
||||
|
||||
Reference in New Issue
Block a user