Working on more hp

This commit is contained in:
Owen
2025-12-03 20:49:46 -05:00
parent 284f1ce627
commit 8c4d6e2e0a
10 changed files with 157 additions and 142 deletions

View File

@@ -11,6 +11,7 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"github.com/fosrl/newt/logger"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
@@ -522,6 +523,7 @@ func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes [
func (b *SharedBind) handleMagicPacket(data []byte, addr *net.UDPAddr) bool { func (b *SharedBind) handleMagicPacket(data []byte, addr *net.UDPAddr) bool {
// Check if this is a test request packet // Check if this is a test request packet
if len(data) >= MagicTestRequestLen && bytes.HasPrefix(data, MagicTestRequest) { if len(data) >= MagicTestRequestLen && bytes.HasPrefix(data, MagicTestRequest) {
logger.Debug("Received magic test REQUEST from %s, sending response", addr.String())
// Extract the random data portion to echo back // Extract the random data portion to echo back
echoData := data[len(MagicTestRequest) : len(MagicTestRequest)+MagicPacketDataLen] echoData := data[len(MagicTestRequest) : len(MagicTestRequest)+MagicPacketDataLen]
@@ -544,6 +546,7 @@ func (b *SharedBind) handleMagicPacket(data []byte, addr *net.UDPAddr) bool {
// Check if this is a test response packet // Check if this is a test response packet
if len(data) >= MagicTestResponseLen && bytes.HasPrefix(data, MagicTestResponse) { if len(data) >= MagicTestResponseLen && bytes.HasPrefix(data, MagicTestResponse) {
logger.Debug("Received magic test RESPONSE from %s", addr.String())
// Extract the echoed data // Extract the echoed data
echoData := data[len(MagicTestResponse) : len(MagicTestResponse)+MagicPacketDataLen] echoData := data[len(MagicTestResponse) : len(MagicTestResponse)+MagicPacketDataLen]
@@ -557,6 +560,8 @@ func (b *SharedBind) handleMagicPacket(data []byte, addr *net.UDPAddr) bool {
addrPort = netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()) addrPort = netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())
} }
callback(addrPort, echoData) callback(addrPort, echoData)
} else {
logger.Debug("Magic response received but no callback registered")
} }
return true return true

View File

@@ -2,8 +2,6 @@ package clients
import ( import (
"context" "context"
"encoding/base64"
"encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
@@ -73,7 +71,6 @@ type WireGuardService struct {
client *websocket.Client client *websocket.Client
config WgConfig config WgConfig
key wgtypes.Key key wgtypes.Key
keyFilePath string
newtId string newtId string
lastReadings map[string]PeerReading lastReadings map[string]PeerReading
mu sync.Mutex mu sync.Mutex
@@ -268,10 +265,20 @@ func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string) {
return return
} }
logger.Info("Starting hole punch to %s with public key: %s", endpoint, publicKey) // Convert websocket.ExitNode to holepunch.ExitNode
if err := s.holePunchManager.StartSingleEndpoint(endpoint, publicKey); err != nil { hpExitNodes := []holepunch.ExitNode{
{
Endpoint: endpoint,
PublicKey: publicKey,
},
}
// Start hole punching using the manager
if err := s.holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil {
logger.Warn("Failed to start hole punch: %v", err) logger.Warn("Failed to start hole punch: %v", err)
} }
logger.Info("Starting hole punch to %s with public key: %s", endpoint, publicKey)
} }
// StartDirectUDPRelay starts a direct UDP relay from the main tunnel netstack to the clients' WireGuard. // StartDirectUDPRelay starts a direct UDP relay from the main tunnel netstack to the clients' WireGuard.
@@ -386,7 +393,7 @@ func (s *WireGuardService) runDirectUDPRelay(listener net.PacketConn) {
continue continue
} }
logger.Debug("Relayed %d bytes from %s into WireGuard", n, srcAddrPort.String()) // logger.Debug("Relayed %d bytes from %s into WireGuard", n, srcAddrPort.String())
} }
} }
@@ -477,11 +484,6 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// Parse the IP address and CIDR mask // Parse the IP address and CIDR mask
tunnelIP := netip.MustParseAddr(parts[0]) tunnelIP := netip.MustParseAddr(parts[0])
// Stop any ongoing hole punch operations
if s.holePunchManager != nil {
s.holePunchManager.Stop()
}
var err error var err error
if s.useNativeInterface { if s.useNativeInterface {
@@ -682,15 +684,6 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err) return fmt.Errorf("invalid CIDR %s: %v", target.DestPrefix, err)
} }
var rewriteTo netip.Prefix
if target.RewriteTo != "" {
rewriteTo, err = netip.ParsePrefix(target.RewriteTo)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.RewriteTo, err)
continue
}
}
var portRanges []netstack2.PortRange var portRanges []netstack2.PortRange
for _, pr := range target.PortRange { for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{ portRanges = append(portRanges, netstack2.PortRange{
@@ -699,7 +692,7 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
}) })
} }
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges) s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges)
logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange) logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange)
} }
@@ -759,6 +752,8 @@ func (s *WireGuardService) handleAddPeer(msg websocket.WSMessage) {
return return
} }
s.holePunchManager.TriggerHolePunch()
err = s.addPeerToDevice(peer) err = s.addPeerToDevice(peer)
if err != nil { if err != nil {
logger.Info("Error adding peer: %v", err) logger.Info("Error adding peer: %v", err)
@@ -836,6 +831,8 @@ func (s *WireGuardService) handleUpdatePeer(msg websocket.WSMessage) {
return return
} }
s.holePunchManager.TriggerHolePunch()
// Parse the public key // Parse the public key
pubKey, err := wgtypes.ParseKey(request.PublicKey) pubKey, err := wgtypes.ParseKey(request.PublicKey)
if err != nil { if err != nil {
@@ -970,13 +967,7 @@ func (s *WireGuardService) calculatePeerBandwidth() ([]PeerBandwidth, error) {
// parse the public keys and have them as base64 in the opposite order to fixKey // parse the public keys and have them as base64 in the opposite order to fixKey
for i := range peerBandwidths { for i := range peerBandwidths {
pubKeyBytes, err := base64.StdEncoding.DecodeString(peerBandwidths[i].PublicKey) peerBandwidths[i].PublicKey = util.UnfixKey(peerBandwidths[i].PublicKey) // its in the long form but we need base64
if err != nil {
logger.Info("Failed to decode public key %s: %v", peerBandwidths[i].PublicKey, err)
continue
}
// Convert to hex
peerBandwidths[i].PublicKey = hex.EncodeToString(pubKeyBytes)
} }
return peerBandwidths, nil return peerBandwidths, nil
@@ -1037,7 +1028,7 @@ func (s *WireGuardService) reportPeerBandwidth() error {
return fmt.Errorf("failed to calculate peer bandwidth: %v", err) return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
} }
err = s.client.SendMessage("newt/receive-bandwidth", map[string]interface{}{ err = s.client.SendMessageNoLog("newt/receive-bandwidth", map[string]interface{}{
"bandwidthData": bandwidths, "bandwidthData": bandwidths,
}) })
if err != nil { if err != nil {
@@ -1084,15 +1075,6 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
continue continue
} }
var rewriteTo netip.Prefix
if target.RewriteTo != "" {
rewriteTo, err = netip.ParsePrefix(target.RewriteTo)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.RewriteTo, err)
continue
}
}
var portRanges []netstack2.PortRange var portRanges []netstack2.PortRange
for _, pr := range target.PortRange { for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{ portRanges = append(portRanges, netstack2.PortRange{
@@ -1101,7 +1083,7 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
}) })
} }
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges) s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges)
logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange) logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange)
} }
@@ -1210,15 +1192,6 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
continue continue
} }
var rewriteTo netip.Prefix
if target.RewriteTo != "" {
rewriteTo, err = netip.ParsePrefix(target.RewriteTo)
if err != nil {
logger.Info("Invalid CIDR %s: %v", target.RewriteTo, err)
continue
}
}
var portRanges []netstack2.PortRange var portRanges []netstack2.PortRange
for _, pr := range target.PortRange { for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{ portRanges = append(portRanges, netstack2.PortRange{
@@ -1227,7 +1200,7 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
}) })
} }
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges) s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges)
logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange) logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange)
} }
} }

View File

@@ -25,7 +25,7 @@ import (
const msgHealthFileWriteFailed = "Failed to write health file: %v" const msgHealthFileWriteFailed = "Failed to write health file: %v"
func ping(tnet *netstack.Net, dst string, timeout time.Duration) (time.Duration, error) { func ping(tnet *netstack.Net, dst string, timeout time.Duration) (time.Duration, error) {
logger.Debug("Pinging %s", dst) // logger.Debug("Pinging %s", dst)
socket, err := tnet.Dial("ping4", dst) socket, err := tnet.Dial("ping4", dst)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to create ICMP socket: %w", err) return 0, fmt.Errorf("failed to create ICMP socket: %w", err)
@@ -84,7 +84,7 @@ func ping(tnet *netstack.Net, dst string, timeout time.Duration) (time.Duration,
latency := time.Since(start) latency := time.Since(start)
logger.Debug("Ping to %s successful, latency: %v", dst, latency) // logger.Debug("Ping to %s successful, latency: %v", dst, latency)
return latency, nil return latency, nil
} }
@@ -122,7 +122,7 @@ func reliablePing(tnet *netstack.Net, dst string, baseTimeout time.Duration, max
// If we get at least one success, we can return early for health checks // If we get at least one success, we can return early for health checks
if successCount > 0 { if successCount > 0 {
avgLatency := totalLatency / time.Duration(successCount) avgLatency := totalLatency / time.Duration(successCount)
logger.Debug("Reliable ping succeeded after %d attempts, avg latency: %v", attempt, avgLatency) // logger.Debug("Reliable ping succeeded after %d attempts, avg latency: %v", attempt, avgLatency)
return avgLatency, nil return avgLatency, nil
} }
} }

View File

@@ -38,7 +38,7 @@ type Manager struct {
sendHolepunchInterval time.Duration sendHolepunchInterval time.Duration
} }
const sendHolepunchIntervalMax = 60 * time.Second const sendHolepunchIntervalMax = 3 * time.Second
const sendHolepunchIntervalMin = 1 * time.Second const sendHolepunchIntervalMin = 1 * time.Second
// NewManager creates a new hole punch manager // NewManager creates a new hole punch manager
@@ -152,6 +152,28 @@ func (m *Manager) GetExitNodes() []ExitNode {
return nodes return nodes
} }
// ResetInterval resets the hole punch interval back to the minimum value,
// allowing it to climb back up through exponential backoff.
// This is useful when network conditions change or connectivity is restored.
func (m *Manager) ResetInterval() {
m.mu.Lock()
defer m.mu.Unlock()
if m.sendHolepunchInterval != sendHolepunchIntervalMin {
m.sendHolepunchInterval = sendHolepunchIntervalMin
logger.Info("Reset hole punch interval to minimum (%v)", sendHolepunchIntervalMin)
}
// Signal the goroutine to apply the new interval if running
if m.running && m.updateChan != nil {
select {
case m.updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
// TriggerHolePunch sends an immediate hole punch packet to all configured exit nodes // TriggerHolePunch sends an immediate hole punch packet to all configured exit nodes
// This is useful for triggering hole punching on demand without waiting for the interval // This is useful for triggering hole punching on demand without waiting for the interval
func (m *Manager) TriggerHolePunch() error { func (m *Manager) TriggerHolePunch() error {
@@ -266,27 +288,6 @@ func (m *Manager) Start() error {
return nil return nil
} }
// StartSingleEndpoint starts hole punching to a single endpoint (legacy mode)
func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error {
m.mu.Lock()
if m.running {
m.mu.Unlock()
logger.Debug("UDP hole punch already running, skipping new request")
return fmt.Errorf("hole punch already running")
}
m.running = true
m.stopChan = make(chan struct{})
m.mu.Unlock()
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
go m.runSingleEndpoint(endpoint, serverPubKey)
return nil
}
// runMultipleExitNodes performs hole punching to multiple exit nodes // runMultipleExitNodes performs hole punching to multiple exit nodes
func (m *Manager) runMultipleExitNodes() { func (m *Manager) runMultipleExitNodes() {
defer func() { defer func() {
@@ -404,67 +405,6 @@ func (m *Manager) runMultipleExitNodes() {
} }
} }
// runSingleEndpoint performs hole punching to a single endpoint
func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
defer func() {
m.mu.Lock()
m.running = false
m.mu.Unlock()
logger.Info("UDP hole punch goroutine ended for %s", endpoint)
}()
host, err := util.ResolveDomain(endpoint)
if err != nil {
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
return
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
return
}
// Execute once immediately before starting the loop
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
logger.Warn("Failed to send initial hole punch: %v", err)
}
// Start with minimum interval
m.mu.Lock()
m.sendHolepunchInterval = sendHolepunchIntervalMin
m.mu.Unlock()
ticker := time.NewTicker(m.sendHolepunchInterval)
defer ticker.Stop()
for {
select {
case <-m.stopChan:
logger.Debug("Hole punch stopped by signal")
return
case <-ticker.C:
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
logger.Debug("Failed to send hole punch: %v", err)
}
// Exponential backoff: double the interval up to max
m.mu.Lock()
newInterval := m.sendHolepunchInterval * 2
if newInterval > sendHolepunchIntervalMax {
newInterval = sendHolepunchIntervalMax
}
if newInterval != m.sendHolepunchInterval {
m.sendHolepunchInterval = newInterval
ticker.Reset(m.sendHolepunchInterval)
logger.Debug("Increased hole punch interval to %v", m.sendHolepunchInterval)
}
m.mu.Unlock()
}
}
}
// sendHolePunch sends an encrypted hole punch packet using the shared bind // sendHolePunch sends an encrypted hole punch packet using the shared bind
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error { func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
m.mu.Lock() m.mu.Lock()

View File

@@ -140,16 +140,19 @@ func (t *HolepunchTester) Stop() {
// handleResponse is called by SharedBind when a magic response is received // handleResponse is called by SharedBind when a magic response is received
func (t *HolepunchTester) handleResponse(addr netip.AddrPort, echoData []byte) { func (t *HolepunchTester) handleResponse(addr netip.AddrPort, echoData []byte) {
logger.Debug("Received magic response from %s", addr.String())
key := string(echoData) key := string(echoData)
value, ok := t.pendingRequests.LoadAndDelete(key) value, ok := t.pendingRequests.LoadAndDelete(key)
if !ok { if !ok {
// No matching request found // No matching request found
logger.Debug("No pending request found for magic response from %s", addr.String())
return return
} }
req := value.(*pendingRequest) req := value.(*pendingRequest)
rtt := time.Since(req.sentAt) rtt := time.Since(req.sentAt)
logger.Debug("Magic response matched pending request for %s (RTT: %v)", req.endpoint, rtt)
// Send RTT to the waiting goroutine (non-blocking) // Send RTT to the waiting goroutine (non-blocking)
select { select {

View File

@@ -3,6 +3,7 @@ package logger
import ( import (
"fmt" "fmt"
"os" "os"
"strings"
"sync" "sync"
"time" "time"
) )
@@ -139,6 +140,10 @@ type WireGuardLogger struct {
func (l *Logger) GetWireGuardLogger(prepend string) *WireGuardLogger { func (l *Logger) GetWireGuardLogger(prepend string) *WireGuardLogger {
return &WireGuardLogger{ return &WireGuardLogger{
Verbosef: func(format string, args ...any) { Verbosef: func(format string, args ...any) {
// if the format string contains "Sending keepalive packet", skip debug logging to reduce noise
if strings.Contains(format, "Sending keepalive packet") {
return
}
l.Debug(prepend+format, args...) l.Debug(prepend+format, args...)
}, },
Errorf: func(format string, args ...any) { Errorf: func(format string, args ...any) {

View File

@@ -1,9 +1,12 @@
package netstack2 package netstack2
import ( import (
"context"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"sync" "sync"
"time"
"gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
@@ -26,14 +29,18 @@ type PortRange struct {
// SubnetRule represents a subnet with optional port restrictions and source address // SubnetRule represents a subnet with optional port restrictions and source address
// When RewriteTo is set, DNAT (Destination Network Address Translation) is performed: // When RewriteTo is set, DNAT (Destination Network Address Translation) is performed:
// - Incoming packets: destination IP is rewritten to RewriteTo.Addr() // - Incoming packets: destination IP is rewritten to the resolved RewriteTo address
// - Outgoing packets: source IP is rewritten back to the original destination // - Outgoing packets: source IP is rewritten back to the original destination
// //
// RewriteTo can be either:
// - An IP address with CIDR notation (e.g., "192.168.1.1/32")
// - A domain name (e.g., "example.com") which will be resolved at request time
//
// This allows transparent proxying where traffic appears to come from the rewritten address // This allows transparent proxying where traffic appears to come from the rewritten address
type SubnetRule struct { type SubnetRule struct {
SourcePrefix netip.Prefix // Source IP prefix (who is sending) SourcePrefix netip.Prefix // Source IP prefix (who is sending)
DestPrefix netip.Prefix // Destination IP prefix (where it's going) DestPrefix netip.Prefix // Destination IP prefix (where it's going)
RewriteTo netip.Prefix // Optional rewrite address for DNAT (destination NAT) RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name
PortRanges []PortRange // empty slice means all ports allowed PortRanges []PortRange // empty slice means all ports allowed
} }
@@ -58,7 +65,8 @@ func NewSubnetLookup() *SubnetLookup {
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions // AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
// If portRanges is nil or empty, all ports are allowed for this subnet // If portRanges is nil or empty, all ports are allowed for this subnet
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix, rewriteTo netip.Prefix, portRanges []PortRange) { // rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) {
sl.mu.Lock() sl.mu.Lock()
defer sl.mu.Unlock() defer sl.mu.Unlock()
@@ -225,8 +233,9 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
// AddSubnetRule adds a subnet with optional port restrictions to the proxy handler // AddSubnetRule adds a subnet with optional port restrictions to the proxy handler
// sourcePrefix: The IP prefix of the peer sending the data // sourcePrefix: The IP prefix of the peer sending the data
// destPrefix: The IP prefix of the destination // destPrefix: The IP prefix of the destination
// rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name
// If portRanges is nil or empty, all ports are allowed for this subnet // If portRanges is nil or empty, all ports are allowed for this subnet
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix, rewriteTo netip.Prefix, portRanges []PortRange) { func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) {
if p == nil || !p.enabled { if p == nil || !p.enabled {
return return
} }
@@ -241,6 +250,43 @@ func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) {
p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix) p.subnetLookup.RemoveSubnet(sourcePrefix, destPrefix)
} }
// resolveRewriteAddress resolves a rewrite address which can be either:
// - An IP address with CIDR notation (e.g., "192.168.1.1/32") - returns the IP directly
// - A plain IP address (e.g., "192.168.1.1") - returns the IP directly
// - A domain name (e.g., "example.com") - performs DNS lookup at request time
func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, error) {
// First, try to parse as a CIDR prefix (e.g., "192.168.1.1/32")
if prefix, err := netip.ParsePrefix(rewriteTo); err == nil {
return prefix.Addr(), nil
}
// Try to parse as a plain IP address (e.g., "192.168.1.1")
if addr, err := netip.ParseAddr(rewriteTo); err == nil {
return addr, nil
}
// Not an IP address, treat as domain name and perform DNS lookup
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ips, err := net.DefaultResolver.LookupIP(ctx, "ip4", rewriteTo)
if err != nil {
return netip.Addr{}, fmt.Errorf("failed to resolve domain %s: %w", rewriteTo, err)
}
if len(ips) == 0 {
return netip.Addr{}, fmt.Errorf("no IP addresses found for domain %s", rewriteTo)
}
// Use the first resolved IP address
ip := ips[0]
if ip4 := ip.To4(); ip4 != nil {
return netip.AddrFrom4([4]byte{ip4[0], ip4[1], ip4[2], ip4[3]}), nil
}
return netip.Addr{}, fmt.Errorf("no IPv4 address found for domain %s", rewriteTo)
}
// Initialize sets up the promiscuous NIC with the netTun's notification system // Initialize sets up the promiscuous NIC with the netTun's notification system
func (p *ProxyHandler) Initialize(notifiable channel.Notification) error { func (p *ProxyHandler) Initialize(notifiable channel.Notification) error {
if p == nil || !p.enabled { if p == nil || !p.enabled {
@@ -334,10 +380,20 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort) matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort)
if matchedRule != nil { if matchedRule != nil {
// Check if we need to perform DNAT // Check if we need to perform DNAT
if matchedRule.RewriteTo.IsValid() && matchedRule.RewriteTo.Addr().IsValid() { if matchedRule.RewriteTo != "" {
// Resolve the rewrite address (could be IP or domain)
newDst, err := p.resolveRewriteAddress(matchedRule.RewriteTo)
if err != nil {
// Failed to resolve, skip DNAT but still proxy the packet
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb)
return true
}
// Perform DNAT - rewrite destination IP // Perform DNAT - rewrite destination IP
originalDst := dstAddr originalDst := dstAddr
newDst := matchedRule.RewriteTo.Addr()
// Create connection tracking key // Create connection tracking key
var srcPort uint16 var srcPort uint16

View File

@@ -350,7 +350,8 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
// AddProxySubnetRule adds a subnet rule to the proxy handler // AddProxySubnetRule adds a subnet rule to the proxy handler
// If portRanges is nil or empty, all ports are allowed for this subnet // If portRanges is nil or empty, all ports are allowed for this subnet
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix, rewriteTo netip.Prefix, portRanges []PortRange) { // rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) {
tun := (*netTun)(net) tun := (*netTun)(net)
if tun.proxyHandler != nil { if tun.proxyHandler != nil {
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges) tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges)

View File

@@ -139,6 +139,18 @@ func FixKey(key string) string {
return hex.EncodeToString(decoded) return hex.EncodeToString(decoded)
} }
// this is the opposite of FixKey
func UnfixKey(hexKey string) string {
// Decode from hex
decoded, err := hex.DecodeString(hexKey)
if err != nil {
logger.Fatal("Error decoding hex: %v", err)
}
// Convert to base64
return base64.StdEncoding.EncodeToString(decoded)
}
func MapToWireGuardLogLevel(level logger.LogLevel) int { func MapToWireGuardLogLevel(level logger.LogLevel) int {
switch level { switch level {
case logger.DEBUG: case logger.DEBUG:

View File

@@ -206,6 +206,26 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
return nil return nil
} }
// SendMessage sends a message through the WebSocket connection
func (c *Client) SendMessageNoLog(messageType string, data interface{}) error {
if c.conn == nil {
return fmt.Errorf("not connected")
}
msg := WSMessage{
Type: messageType,
Data: data,
}
c.writeMux.Lock()
defer c.writeMux.Unlock()
if err := c.conn.WriteJSON(msg); err != nil {
return err
}
telemetry.IncWSMessage(c.metricsContext(), "out", "text")
return nil
}
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) { func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
stopChan := make(chan struct{}) stopChan := make(chan struct{})
go func() { go func() {