mirror of
https://github.com/fosrl/newt.git
synced 2026-03-26 12:36:45 +00:00
Allow passing public dns into resolve
This commit is contained in:
@@ -162,9 +162,8 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
|
|||||||
useNativeInterface: useNativeInterface,
|
useNativeInterface: useNativeInterface,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the holepunch manager with ResolveDomain function
|
// Create the holepunch manager
|
||||||
// We'll need to pass a domain resolver function
|
service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String(), nil)
|
||||||
service.holePunchManager = holepunch.NewManager(sharedBind, newtId, "newt", key.PublicKey().String())
|
|
||||||
|
|
||||||
// Register websocket handlers
|
// Register websocket handlers
|
||||||
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
|
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ type Manager struct {
|
|||||||
clientType string
|
clientType string
|
||||||
exitNodes map[string]ExitNode // key is endpoint
|
exitNodes map[string]ExitNode // key is endpoint
|
||||||
updateChan chan struct{} // signals the goroutine to refresh exit nodes
|
updateChan chan struct{} // signals the goroutine to refresh exit nodes
|
||||||
|
publicDNS []string
|
||||||
|
|
||||||
sendHolepunchInterval time.Duration
|
sendHolepunchInterval time.Duration
|
||||||
sendHolepunchIntervalMin time.Duration
|
sendHolepunchIntervalMin time.Duration
|
||||||
@@ -49,12 +50,13 @@ const defaultSendHolepunchIntervalMax = 60 * time.Second
|
|||||||
const defaultSendHolepunchIntervalMin = 1 * time.Second
|
const defaultSendHolepunchIntervalMin = 1 * time.Second
|
||||||
|
|
||||||
// NewManager creates a new hole punch manager
|
// NewManager creates a new hole punch manager
|
||||||
func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string) *Manager {
|
func NewManager(sharedBind *bind.SharedBind, ID string, clientType string, publicKey string, publicDNS []string) *Manager {
|
||||||
return &Manager{
|
return &Manager{
|
||||||
sharedBind: sharedBind,
|
sharedBind: sharedBind,
|
||||||
ID: ID,
|
ID: ID,
|
||||||
clientType: clientType,
|
clientType: clientType,
|
||||||
publicKey: publicKey,
|
publicKey: publicKey,
|
||||||
|
publicDNS: publicDNS,
|
||||||
exitNodes: make(map[string]ExitNode),
|
exitNodes: make(map[string]ExitNode),
|
||||||
sendHolepunchInterval: defaultSendHolepunchIntervalMin,
|
sendHolepunchInterval: defaultSendHolepunchIntervalMin,
|
||||||
sendHolepunchIntervalMin: defaultSendHolepunchIntervalMin,
|
sendHolepunchIntervalMin: defaultSendHolepunchIntervalMin,
|
||||||
@@ -281,7 +283,13 @@ func (m *Manager) TriggerHolePunch() error {
|
|||||||
// Send hole punch to all exit nodes
|
// Send hole punch to all exit nodes
|
||||||
successCount := 0
|
successCount := 0
|
||||||
for _, exitNode := range currentExitNodes {
|
for _, exitNode := range currentExitNodes {
|
||||||
host, err := util.ResolveDomain(exitNode.Endpoint)
|
var host string
|
||||||
|
var err error
|
||||||
|
if len(m.publicDNS) > 0 {
|
||||||
|
host, err = util.ResolveDomainUpstream(exitNode.Endpoint, m.publicDNS)
|
||||||
|
} else {
|
||||||
|
host, err = util.ResolveDomain(exitNode.Endpoint)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||||
continue
|
continue
|
||||||
@@ -392,7 +400,13 @@ func (m *Manager) runMultipleExitNodes() {
|
|||||||
|
|
||||||
var resolvedNodes []resolvedExitNode
|
var resolvedNodes []resolvedExitNode
|
||||||
for _, exitNode := range currentExitNodes {
|
for _, exitNode := range currentExitNodes {
|
||||||
host, err := util.ResolveDomain(exitNode.Endpoint)
|
var host string
|
||||||
|
var err error
|
||||||
|
if len(m.publicDNS) > 0 {
|
||||||
|
host, err = util.ResolveDomainUpstream(exitNode.Endpoint, m.publicDNS)
|
||||||
|
} else {
|
||||||
|
host, err = util.ResolveDomain(exitNode.Endpoint)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ type cachedAddr struct {
|
|||||||
// HolepunchTester monitors holepunch connectivity using magic packets
|
// HolepunchTester monitors holepunch connectivity using magic packets
|
||||||
type HolepunchTester struct {
|
type HolepunchTester struct {
|
||||||
sharedBind *bind.SharedBind
|
sharedBind *bind.SharedBind
|
||||||
|
publicDNS []string
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
running bool
|
running bool
|
||||||
stopChan chan struct{}
|
stopChan chan struct{}
|
||||||
@@ -84,9 +85,10 @@ type pendingRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewHolepunchTester creates a new holepunch tester using the given SharedBind
|
// NewHolepunchTester creates a new holepunch tester using the given SharedBind
|
||||||
func NewHolepunchTester(sharedBind *bind.SharedBind) *HolepunchTester {
|
func NewHolepunchTester(sharedBind *bind.SharedBind, publicDNS []string) *HolepunchTester {
|
||||||
return &HolepunchTester{
|
return &HolepunchTester{
|
||||||
sharedBind: sharedBind,
|
sharedBind: sharedBind,
|
||||||
|
publicDNS: publicDNS,
|
||||||
addrCache: make(map[string]*cachedAddr),
|
addrCache: make(map[string]*cachedAddr),
|
||||||
addrCacheTTL: 5 * time.Minute, // Cache addresses for 5 minutes
|
addrCacheTTL: 5 * time.Minute, // Cache addresses for 5 minutes
|
||||||
}
|
}
|
||||||
@@ -169,7 +171,13 @@ func (t *HolepunchTester) resolveEndpoint(endpoint string) (*net.UDPAddr, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Resolve the endpoint
|
// Resolve the endpoint
|
||||||
host, err := util.ResolveDomain(endpoint)
|
var host string
|
||||||
|
var err error
|
||||||
|
if len(t.publicDNS) > 0 {
|
||||||
|
host, err = util.ResolveDomainUpstream(endpoint, t.publicDNS)
|
||||||
|
} else {
|
||||||
|
host, err = util.ResolveDomain(endpoint)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
host = endpoint
|
host = endpoint
|
||||||
}
|
}
|
||||||
|
|||||||
94
util/util.go
94
util/util.go
@@ -1,6 +1,7 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
@@ -14,6 +15,99 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func ResolveDomainUpstream(domain string, publicDNS []string) (string, error) {
|
||||||
|
// trim whitespace
|
||||||
|
domain = strings.TrimSpace(domain)
|
||||||
|
|
||||||
|
// Remove any protocol prefix if present (do this first, before splitting host/port)
|
||||||
|
domain = strings.TrimPrefix(domain, "http://")
|
||||||
|
domain = strings.TrimPrefix(domain, "https://")
|
||||||
|
|
||||||
|
// if there are any trailing slashes, remove them
|
||||||
|
domain = strings.TrimSuffix(domain, "/")
|
||||||
|
|
||||||
|
// Check if there's a port in the domain
|
||||||
|
host, port, err := net.SplitHostPort(domain)
|
||||||
|
if err != nil {
|
||||||
|
// No port found, use the domain as is
|
||||||
|
host = domain
|
||||||
|
port = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if host is already an IP address (IPv4 or IPv6)
|
||||||
|
// For IPv6, the host from SplitHostPort will already have brackets stripped
|
||||||
|
// but if there was no port, we need to handle bracketed IPv6 addresses
|
||||||
|
cleanHost := strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[")
|
||||||
|
if ip := net.ParseIP(cleanHost); ip != nil {
|
||||||
|
// It's already an IP address, no need to resolve
|
||||||
|
ipAddr := ip.String()
|
||||||
|
if port != "" {
|
||||||
|
return net.JoinHostPort(ipAddr, port), nil
|
||||||
|
}
|
||||||
|
return ipAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup IP addresses using the upstream DNS servers if provided
|
||||||
|
var ips []net.IP
|
||||||
|
if len(publicDNS) > 0 {
|
||||||
|
var lastErr error
|
||||||
|
for _, server := range publicDNS {
|
||||||
|
// Ensure the upstream DNS address has a port
|
||||||
|
dnsAddr := server
|
||||||
|
if _, _, err := net.SplitHostPort(dnsAddr); err != nil {
|
||||||
|
// No port specified, default to 53
|
||||||
|
dnsAddr = net.JoinHostPort(server, "53")
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver := &net.Resolver{
|
||||||
|
PreferGo: true,
|
||||||
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
d := net.Dialer{}
|
||||||
|
return d.DialContext(ctx, "udp", dnsAddr)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ips, lastErr = resolver.LookupIP(context.Background(), "ip", host)
|
||||||
|
if lastErr == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lastErr != nil {
|
||||||
|
return "", fmt.Errorf("DNS lookup failed using all upstream servers: %v", lastErr)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ips, err = net.LookupIP(host)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("DNS lookup failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ips) == 0 {
|
||||||
|
return "", fmt.Errorf("no IP addresses found for domain %s", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the first IPv4 address if available
|
||||||
|
var ipAddr string
|
||||||
|
for _, ip := range ips {
|
||||||
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
|
ipAddr = ipv4.String()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no IPv4 found, use the first IP (might be IPv6)
|
||||||
|
if ipAddr == "" {
|
||||||
|
ipAddr = ips[0].String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add port back if it existed
|
||||||
|
if port != "" {
|
||||||
|
ipAddr = net.JoinHostPort(ipAddr, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
func ResolveDomain(domain string) (string, error) {
|
func ResolveDomain(domain string) (string, error) {
|
||||||
// trim whitespace
|
// trim whitespace
|
||||||
domain = strings.TrimSpace(domain)
|
domain = strings.TrimSpace(domain)
|
||||||
|
|||||||
Reference in New Issue
Block a user