Update to use new packages

This commit is contained in:
Owen
2025-11-15 16:14:40 -05:00
parent 972c9a9760
commit c71c6e0b1a
9 changed files with 1314 additions and 291 deletions

378
bind/shared_bind.go Normal file
View File

@@ -0,0 +1,378 @@
//go:build !js
package bind
import (
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"sync/atomic"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
)
// Endpoint represents a network endpoint for the SharedBind
type Endpoint struct {
AddrPort netip.AddrPort
}
// ClearSrc implements the wgConn.Endpoint interface
func (e *Endpoint) ClearSrc() {}
// DstIP implements the wgConn.Endpoint interface
func (e *Endpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
// SrcIP implements the wgConn.Endpoint interface
func (e *Endpoint) SrcIP() netip.Addr {
return netip.Addr{}
}
// DstToBytes implements the wgConn.Endpoint interface
func (e *Endpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
return b
}
// DstToString implements the wgConn.Endpoint interface
func (e *Endpoint) DstToString() string {
return e.AddrPort.String()
}
// SrcToString implements the wgConn.Endpoint interface
func (e *Endpoint) SrcToString() string {
return ""
}
// SharedBind is a thread-safe UDP bind that can be shared between WireGuard
// and hole punch senders. It wraps a single UDP connection and implements
// reference counting to prevent premature closure.
type SharedBind struct {
mu sync.RWMutex
// The underlying UDP connection
udpConn *net.UDPConn
// IPv4 and IPv6 packet connections for advanced features
ipv4PC *ipv4.PacketConn
ipv6PC *ipv6.PacketConn
// Reference counting to prevent closing while in use
refCount atomic.Int32
closed atomic.Bool
// Channels for receiving data
recvFuncs []wgConn.ReceiveFunc
// Port binding information
port uint16
}
// New creates a new SharedBind from an existing UDP connection.
// The SharedBind takes ownership of the connection and will close it
// when all references are released.
func New(udpConn *net.UDPConn) (*SharedBind, error) {
if udpConn == nil {
return nil, fmt.Errorf("udpConn cannot be nil")
}
bind := &SharedBind{
udpConn: udpConn,
}
// Initialize reference count to 1 (the creator holds the first reference)
bind.refCount.Store(1)
// Get the local port
if addr, ok := udpConn.LocalAddr().(*net.UDPAddr); ok {
bind.port = uint16(addr.Port)
}
return bind, nil
}
// AddRef increments the reference count. Call this when sharing
// the bind with another component.
func (b *SharedBind) AddRef() {
newCount := b.refCount.Add(1)
// Optional: Add logging for debugging
_ = newCount // Placeholder for potential logging
}
// Release decrements the reference count. When it reaches zero,
// the underlying UDP connection is closed.
func (b *SharedBind) Release() error {
newCount := b.refCount.Add(-1)
// Optional: Add logging for debugging
_ = newCount // Placeholder for potential logging
if newCount < 0 {
// This should never happen with proper usage
b.refCount.Store(0)
return fmt.Errorf("SharedBind reference count went negative")
}
if newCount == 0 {
return b.closeConnection()
}
return nil
}
// closeConnection actually closes the UDP connection
func (b *SharedBind) closeConnection() error {
if !b.closed.CompareAndSwap(false, true) {
// Already closed
return nil
}
b.mu.Lock()
defer b.mu.Unlock()
var err error
if b.udpConn != nil {
err = b.udpConn.Close()
b.udpConn = nil
}
b.ipv4PC = nil
b.ipv6PC = nil
return err
}
// GetUDPConn returns the underlying UDP connection.
// The caller must not close this connection directly.
func (b *SharedBind) GetUDPConn() *net.UDPConn {
b.mu.RLock()
defer b.mu.RUnlock()
return b.udpConn
}
// GetRefCount returns the current reference count (for debugging)
func (b *SharedBind) GetRefCount() int32 {
return b.refCount.Load()
}
// IsClosed returns whether the bind is closed
func (b *SharedBind) IsClosed() bool {
return b.closed.Load()
}
// WriteToUDP writes data to a specific UDP address.
// This is thread-safe and can be used by hole punch senders.
func (b *SharedBind) WriteToUDP(data []byte, addr *net.UDPAddr) (int, error) {
if b.closed.Load() {
return 0, net.ErrClosed
}
b.mu.RLock()
conn := b.udpConn
b.mu.RUnlock()
if conn == nil {
return 0, net.ErrClosed
}
return conn.WriteToUDP(data, addr)
}
// Close implements the WireGuard Bind interface.
// It decrements the reference count and closes the connection if no references remain.
func (b *SharedBind) Close() error {
return b.Release()
}
// Open implements the WireGuard Bind interface.
// Since the connection is already open, this just sets up the receive functions.
func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
if b.closed.Load() {
return nil, 0, net.ErrClosed
}
b.mu.Lock()
defer b.mu.Unlock()
if b.udpConn == nil {
return nil, 0, net.ErrClosed
}
// Set up IPv4 and IPv6 packet connections for advanced features
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
b.ipv4PC = ipv4.NewPacketConn(b.udpConn)
b.ipv6PC = ipv6.NewPacketConn(b.udpConn)
}
// Create receive functions
recvFuncs := make([]wgConn.ReceiveFunc, 0, 2)
// Add IPv4 receive function
if b.ipv4PC != nil || runtime.GOOS != "linux" {
recvFuncs = append(recvFuncs, b.makeReceiveIPv4())
}
// Add IPv6 receive function if needed
// For now, we focus on IPv4 for hole punching use case
b.recvFuncs = recvFuncs
return recvFuncs, b.port, nil
}
// makeReceiveIPv4 creates a receive function for IPv4 packets
func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
if b.closed.Load() {
return 0, net.ErrClosed
}
b.mu.RLock()
conn := b.udpConn
pc := b.ipv4PC
b.mu.RUnlock()
if conn == nil {
return 0, net.ErrClosed
}
// Use batch reading on Linux for performance
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
return b.receiveIPv4Batch(pc, bufs, sizes, eps)
}
// Fallback to simple read for other platforms
return b.receiveIPv4Simple(conn, bufs, sizes, eps)
}
}
// receiveIPv4Batch uses batch reading for better performance on Linux
func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
// Create messages for batch reading
msgs := make([]ipv4.Message, len(bufs))
for i := range bufs {
msgs[i].Buffers = [][]byte{bufs[i]}
msgs[i].OOB = make([]byte, 0) // No OOB data needed for basic use
}
numMsgs, err := pc.ReadBatch(msgs, 0)
if err != nil {
return 0, err
}
for i := 0; i < numMsgs; i++ {
sizes[i] = msgs[i].N
if sizes[i] == 0 {
continue
}
if msgs[i].Addr != nil {
if udpAddr, ok := msgs[i].Addr.(*net.UDPAddr); ok {
addrPort := udpAddr.AddrPort()
eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
}
}
}
return numMsgs, nil
}
// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms
func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
n, addr, err := conn.ReadFromUDP(bufs[0])
if err != nil {
return 0, err
}
sizes[0] = n
if addr != nil {
addrPort := addr.AddrPort()
eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
}
return 1, nil
}
// Send implements the WireGuard Bind interface.
// It sends packets to the specified endpoint.
func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
if b.closed.Load() {
return net.ErrClosed
}
b.mu.RLock()
conn := b.udpConn
b.mu.RUnlock()
if conn == nil {
return net.ErrClosed
}
// Extract the destination address from the endpoint
var destAddr *net.UDPAddr
// Try to cast to StdNetEndpoint first
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
destAddr = net.UDPAddrFromAddrPort(stdEp.AddrPort)
} else {
// Fallback: construct from DstIP and DstToBytes
dstBytes := ep.DstToBytes()
if len(dstBytes) >= 6 { // Minimum for IPv4 (4 bytes) + port (2 bytes)
var addr netip.Addr
var port uint16
if len(dstBytes) >= 18 { // IPv6 (16 bytes) + port (2 bytes)
addr, _ = netip.AddrFromSlice(dstBytes[:16])
port = uint16(dstBytes[16]) | uint16(dstBytes[17])<<8
} else { // IPv4
addr, _ = netip.AddrFromSlice(dstBytes[:4])
port = uint16(dstBytes[4]) | uint16(dstBytes[5])<<8
}
if addr.IsValid() {
destAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port))
}
}
}
if destAddr == nil {
return fmt.Errorf("could not extract destination address from endpoint")
}
// Send all buffers to the destination
for _, buf := range bufs {
_, err := conn.WriteToUDP(buf, destAddr)
if err != nil {
return err
}
}
return nil
}
// SetMark implements the WireGuard Bind interface.
// It's a no-op for this implementation.
func (b *SharedBind) SetMark(mark uint32) error {
// Not implemented for this use case
return nil
}
// BatchSize returns the preferred batch size for sending packets.
func (b *SharedBind) BatchSize() int {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
return wgConn.IdealBatchSize
}
return 1
}
// ParseEndpoint creates a new endpoint from a string address.
func (b *SharedBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
addrPort, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
return &wgConn.StdNetEndpoint{AddrPort: addrPort}, nil
}

424
bind/shared_bind_test.go Normal file
View File

@@ -0,0 +1,424 @@
//go:build !js
package bind
import (
"net"
"net/netip"
"sync"
"testing"
"time"
wgConn "golang.zx2c4.com/wireguard/conn"
)
// TestSharedBindCreation tests basic creation and initialization
func TestSharedBindCreation(t *testing.T) {
// Create a UDP connection
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
defer udpConn.Close()
// Create SharedBind
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
if bind == nil {
t.Fatal("SharedBind is nil")
}
// Verify initial reference count
if bind.refCount.Load() != 1 {
t.Errorf("Expected initial refCount to be 1, got %d", bind.refCount.Load())
}
// Clean up
if err := bind.Close(); err != nil {
t.Errorf("Failed to close SharedBind: %v", err)
}
}
// TestSharedBindReferenceCount tests reference counting
func TestSharedBindReferenceCount(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
// Add references
bind.AddRef()
if bind.refCount.Load() != 2 {
t.Errorf("Expected refCount to be 2, got %d", bind.refCount.Load())
}
bind.AddRef()
if bind.refCount.Load() != 3 {
t.Errorf("Expected refCount to be 3, got %d", bind.refCount.Load())
}
// Release references
bind.Release()
if bind.refCount.Load() != 2 {
t.Errorf("Expected refCount to be 2 after release, got %d", bind.refCount.Load())
}
bind.Release()
bind.Release() // This should close the connection
if !bind.closed.Load() {
t.Error("Expected bind to be closed after all references released")
}
}
// TestSharedBindWriteToUDP tests the WriteToUDP functionality
func TestSharedBindWriteToUDP(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
senderBind, err := New(senderConn)
if err != nil {
t.Fatalf("Failed to create sender SharedBind: %v", err)
}
defer senderBind.Close()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
// Send data
testData := []byte("Hello, SharedBind!")
n, err := senderBind.WriteToUDP(testData, receiverAddr)
if err != nil {
t.Fatalf("WriteToUDP failed: %v", err)
}
if n != len(testData) {
t.Errorf("Expected to send %d bytes, sent %d", len(testData), n)
}
// Receive data
buf := make([]byte, 1024)
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, _, err = receiverConn.ReadFromUDP(buf)
if err != nil {
t.Fatalf("Failed to receive data: %v", err)
}
if string(buf[:n]) != string(testData) {
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
}
}
// TestSharedBindConcurrentWrites tests thread-safety
func TestSharedBindConcurrentWrites(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
senderBind, err := New(senderConn)
if err != nil {
t.Fatalf("Failed to create sender SharedBind: %v", err)
}
defer senderBind.Close()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
// Launch concurrent writes
numGoroutines := 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
data := []byte{byte(id)}
_, err := senderBind.WriteToUDP(data, receiverAddr)
if err != nil {
t.Errorf("WriteToUDP failed in goroutine %d: %v", id, err)
}
}(i)
}
wg.Wait()
}
// TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation
func TestSharedBindWireGuardInterface(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
defer bind.Close()
// Test Open
recvFuncs, port, err := bind.Open(0)
if err != nil {
t.Fatalf("Open failed: %v", err)
}
if len(recvFuncs) == 0 {
t.Error("Expected at least one receive function")
}
if port == 0 {
t.Error("Expected non-zero port")
}
// Test SetMark (should be a no-op)
if err := bind.SetMark(0); err != nil {
t.Errorf("SetMark failed: %v", err)
}
// Test BatchSize
batchSize := bind.BatchSize()
if batchSize <= 0 {
t.Error("Expected positive batch size")
}
}
// TestSharedBindSend tests the Send method with WireGuard endpoints
func TestSharedBindSend(t *testing.T) {
// Create sender
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create sender UDP connection: %v", err)
}
senderBind, err := New(senderConn)
if err != nil {
t.Fatalf("Failed to create sender SharedBind: %v", err)
}
defer senderBind.Close()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
// Create an endpoint
addrPort := receiverAddr.AddrPort()
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
// Send data
testData := []byte("WireGuard packet")
bufs := [][]byte{testData}
err = senderBind.Send(bufs, endpoint)
if err != nil {
t.Fatalf("Send failed: %v", err)
}
// Receive data
buf := make([]byte, 1024)
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, _, err := receiverConn.ReadFromUDP(buf)
if err != nil {
t.Fatalf("Failed to receive data: %v", err)
}
if string(buf[:n]) != string(testData) {
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
}
}
// TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind
func TestSharedBindMultipleUsers(t *testing.T) {
// Create shared bind
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
sharedBind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
// Add reference for hole punch sender
sharedBind.AddRef()
// Create receiver
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create receiver UDP connection: %v", err)
}
defer receiverConn.Close()
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
var wg sync.WaitGroup
// Simulate WireGuard using the bind
wg.Add(1)
go func() {
defer wg.Done()
addrPort := receiverAddr.AddrPort()
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
for i := 0; i < 10; i++ {
data := []byte("WireGuard packet")
bufs := [][]byte{data}
if err := sharedBind.Send(bufs, endpoint); err != nil {
t.Errorf("WireGuard Send failed: %v", err)
}
time.Sleep(10 * time.Millisecond)
}
}()
// Simulate hole punch sender using the bind
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 10; i++ {
data := []byte("Hole punch packet")
if _, err := sharedBind.WriteToUDP(data, receiverAddr); err != nil {
t.Errorf("Hole punch WriteToUDP failed: %v", err)
}
time.Sleep(10 * time.Millisecond)
}
}()
wg.Wait()
// Release the hole punch reference
sharedBind.Release()
// Close WireGuard's reference (should close the connection)
sharedBind.Close()
if !sharedBind.closed.Load() {
t.Error("Expected bind to be closed after all users released it")
}
}
// TestEndpoint tests the Endpoint implementation
func TestEndpoint(t *testing.T) {
addr := netip.MustParseAddr("192.168.1.1")
addrPort := netip.AddrPortFrom(addr, 51820)
ep := &Endpoint{AddrPort: addrPort}
// Test DstIP
if ep.DstIP() != addr {
t.Errorf("Expected DstIP to be %v, got %v", addr, ep.DstIP())
}
// Test DstToString
expected := "192.168.1.1:51820"
if ep.DstToString() != expected {
t.Errorf("Expected DstToString to be %q, got %q", expected, ep.DstToString())
}
// Test DstToBytes
bytes := ep.DstToBytes()
if len(bytes) == 0 {
t.Error("Expected DstToBytes to return non-empty slice")
}
// Test SrcIP (should be zero)
if ep.SrcIP().IsValid() {
t.Error("Expected SrcIP to be invalid")
}
// Test ClearSrc (should not panic)
ep.ClearSrc()
}
// TestParseEndpoint tests the ParseEndpoint method
func TestParseEndpoint(t *testing.T) {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create UDP connection: %v", err)
}
bind, err := New(udpConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
defer bind.Close()
tests := []struct {
name string
input string
wantErr bool
checkAddr func(*testing.T, wgConn.Endpoint)
}{
{
name: "valid IPv4",
input: "192.168.1.1:51820",
wantErr: false,
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
if ep.DstToString() != "192.168.1.1:51820" {
t.Errorf("Expected 192.168.1.1:51820, got %s", ep.DstToString())
}
},
},
{
name: "valid IPv6",
input: "[::1]:51820",
wantErr: false,
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
if ep.DstToString() != "[::1]:51820" {
t.Errorf("Expected [::1]:51820, got %s", ep.DstToString())
}
},
},
{
name: "invalid - missing port",
input: "192.168.1.1",
wantErr: true,
},
{
name: "invalid - bad format",
input: "not-an-address",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ep, err := bind.ParseEndpoint(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("ParseEndpoint() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && tt.checkAddr != nil {
tt.checkAddr(t, ep)
}
})
}
}

View File

@@ -7,7 +7,6 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"net"
"os"
"os/exec"
"strings"
@@ -398,57 +397,6 @@ func mapToWireGuardLogLevel(level logger.LogLevel) int {
}
}
func resolveDomain(domain string) (string, error) {
// Check if there's a port in the domain
host, port, err := net.SplitHostPort(domain)
if err != nil {
// No port found, use the domain as is
host = domain
port = ""
}
// Remove any protocol prefix if present
if strings.HasPrefix(host, "http://") {
host = strings.TrimPrefix(host, "http://")
} else if strings.HasPrefix(host, "https://") {
host = strings.TrimPrefix(host, "https://")
}
// if there are any trailing slashes, remove them
host = strings.TrimSuffix(host, "/")
// Lookup IP addresses
ips, err := net.LookupIP(host)
if err != nil {
return "", fmt.Errorf("DNS lookup failed: %v", err)
}
if len(ips) == 0 {
return "", fmt.Errorf("no IP addresses found for domain %s", host)
}
// Get the first IPv4 address if available
var ipAddr string
for _, ip := range ips {
if ipv4 := ip.To4(); ipv4 != nil {
ipAddr = ipv4.String()
break
}
}
// If no IPv4 found, use the first IP (might be IPv6)
if ipAddr == "" {
ipAddr = ips[0].String()
}
// Add port back if it existed
if port != "" {
ipAddr = net.JoinHostPort(ipAddr, port)
}
return ipAddr, nil
}
func parseTargetData(data interface{}) (TargetData, error) {
var targetData TargetData
jsonData, err := json.Marshal(data)

16
go.mod
View File

@@ -17,9 +17,9 @@ require (
go.opentelemetry.io/otel/metric v1.38.0
go.opentelemetry.io/otel/sdk v1.38.0
go.opentelemetry.io/otel/sdk/metric v1.38.0
golang.org/x/crypto v0.43.0
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
golang.org/x/net v0.46.0
golang.org/x/crypto v0.44.0
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6
golang.org/x/net v0.47.0
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
google.golang.org/grpc v1.76.0
@@ -69,12 +69,12 @@ require (
go.opentelemetry.io/otel/trace v1.38.0 // indirect
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/mod v0.28.0 // indirect
golang.org/x/sync v0.17.0 // indirect
golang.org/x/sys v0.37.0 // indirect
golang.org/x/text v0.30.0 // indirect
golang.org/x/mod v0.30.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
golang.org/x/time v0.12.0 // indirect
golang.org/x/tools v0.37.0 // indirect
golang.org/x/tools v0.39.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect

15
go.sum
View File

@@ -107,32 +107,47 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.28.0 h1:gQBtGhjxykdjY9YhZpSlZIsbnaE2+PgjfLWUQTnoZ1U=
golang.org/x/mod v0.28.0/go.mod h1:yfB/L0NOf/kmEbXjzCPOx1iK1fRutOydrCMsqRhEBxI=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.37.0 h1:DVSRzp7FwePZW356yEAChSdNcQo6Nsp+fex1SUW09lE=
golang.org/x/tools v0.37.0/go.mod h1:MBN5QPQtLMHVdvsbtarmTNukZDdgwdwlO5qGacAzF0w=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=

347
holepunch/holepunch.go Normal file
View File

@@ -0,0 +1,347 @@
package holepunch
import (
"encoding/json"
"fmt"
"net"
"sync"
"time"
"github.com/fosrl/newt/bind"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.org/x/exp/rand"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// ExitNode represents a WireGuard exit node for hole punching
type ExitNode struct {
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
}
// Manager handles UDP hole punching operations
type Manager struct {
mu sync.Mutex
running bool
stopChan chan struct{}
sharedBind *bind.SharedBind
newtID string
token string
}
// NewManager creates a new hole punch manager
func NewManager(sharedBind *bind.SharedBind, newtID string) *Manager {
return &Manager{
sharedBind: sharedBind,
newtID: newtID,
}
}
// SetToken updates the authentication token used for hole punching
func (m *Manager) SetToken(token string) {
m.mu.Lock()
defer m.mu.Unlock()
m.token = token
}
// IsRunning returns whether hole punching is currently active
func (m *Manager) IsRunning() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.running
}
// Stop stops any ongoing hole punch operations
func (m *Manager) Stop() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.running {
return
}
if m.stopChan != nil {
close(m.stopChan)
m.stopChan = nil
}
m.running = false
logger.Info("Hole punch manager stopped")
}
// StartMultipleExitNodes starts hole punching to multiple exit nodes
func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
m.mu.Lock()
if m.running {
m.mu.Unlock()
logger.Debug("UDP hole punch already running, skipping new request")
return fmt.Errorf("hole punch already running")
}
if len(exitNodes) == 0 {
m.mu.Unlock()
logger.Warn("No exit nodes provided for hole punching")
return fmt.Errorf("no exit nodes provided")
}
m.running = true
m.stopChan = make(chan struct{})
m.mu.Unlock()
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
go m.runMultipleExitNodes(exitNodes)
return nil
}
// StartSingleEndpoint starts hole punching to a single endpoint (legacy mode)
func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error {
m.mu.Lock()
if m.running {
m.mu.Unlock()
logger.Debug("UDP hole punch already running, skipping new request")
return fmt.Errorf("hole punch already running")
}
m.running = true
m.stopChan = make(chan struct{})
m.mu.Unlock()
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
go m.runSingleEndpoint(endpoint, serverPubKey)
return nil
}
// runMultipleExitNodes performs hole punching to multiple exit nodes
func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
defer func() {
m.mu.Lock()
m.running = false
m.mu.Unlock()
logger.Info("UDP hole punch goroutine ended for all exit nodes")
}()
// Resolve all endpoints upfront
type resolvedExitNode struct {
remoteAddr *net.UDPAddr
publicKey string
endpointName string
}
var resolvedNodes []resolvedExitNode
for _, exitNode := range exitNodes {
host, err := util.ResolveDomain(exitNode.Endpoint)
if err != nil {
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
continue
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
continue
}
resolvedNodes = append(resolvedNodes, resolvedExitNode{
remoteAddr: remoteAddr,
publicKey: exitNode.PublicKey,
endpointName: exitNode.Endpoint,
})
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
}
if len(resolvedNodes) == 0 {
logger.Error("No exit nodes could be resolved")
return
}
// Send initial hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err)
}
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-m.stopChan:
logger.Debug("Hole punch stopped by signal")
return
case <-timeout.C:
logger.Debug("Hole punch timeout reached")
return
case <-ticker.C:
// Send hole punch to all exit nodes
for _, node := range resolvedNodes {
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
}
}
}
}
}
// runSingleEndpoint performs hole punching to a single endpoint
func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
defer func() {
m.mu.Lock()
m.running = false
m.mu.Unlock()
logger.Info("UDP hole punch goroutine ended for %s", endpoint)
}()
host, err := util.ResolveDomain(endpoint)
if err != nil {
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
return
}
serverAddr := net.JoinHostPort(host, "21820")
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
if err != nil {
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
return
}
// Execute once immediately before starting the loop
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
logger.Warn("Failed to send initial hole punch: %v", err)
}
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-m.stopChan:
logger.Debug("Hole punch stopped by signal")
return
case <-timeout.C:
logger.Debug("Hole punch timeout reached")
return
case <-ticker.C:
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
logger.Debug("Failed to send hole punch: %v", err)
}
}
}
}
// sendHolePunch sends an encrypted hole punch packet using the shared bind
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
m.mu.Lock()
token := m.token
newtID := m.newtID
m.mu.Unlock()
if serverPubKey == "" || token == "" {
return fmt.Errorf("server public key or OLM token is empty")
}
payload := struct {
NewtID string `json:"newtId"`
Token string `json:"token"`
}{
NewtID: newtID,
Token: token,
}
// Convert payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %w", err)
}
// Encrypt the payload using the server's WireGuard public key
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
if err != nil {
return fmt.Errorf("failed to encrypt payload: %w", err)
}
jsonData, err := json.Marshal(encryptedPayload)
if err != nil {
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
}
_, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr)
if err != nil {
return fmt.Errorf("failed to write to UDP: %w", err)
}
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
return nil
}
// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
// Generate an ephemeral keypair for this message
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
}
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
// Parse the server's public key
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
if err != nil {
return nil, fmt.Errorf("failed to parse server public key: %v", err)
}
// Use X25519 for key exchange
var ephPrivKeyFixed [32]byte
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
// Perform X25519 key exchange
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
if err != nil {
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
}
// Create an AEAD cipher using the shared secret
aead, err := chacha20poly1305.New(sharedSecret)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %v", err)
}
// Encrypt the payload
ciphertext := aead.Seal(nil, nonce, payload, nil)
// Prepare the final encrypted message
encryptedMsg := struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}{
EphemeralPublicKey: ephemeralPublicKey.String(),
Nonce: nonce,
Ciphertext: ciphertext,
}
return encryptedMsg, nil
}

View File

@@ -22,6 +22,7 @@ import (
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/updates"
"github.com/fosrl/newt/util"
"github.com/fosrl/newt/websocket"
"github.com/fosrl/newt/internal/state"
@@ -663,7 +664,7 @@ func main() {
logger.Info("Connecting to endpoint: %s", host)
endpoint, err := resolveDomain(wgData.Endpoint)
endpoint, err := util.ResolveDomain(wgData.Endpoint)
if err != nil {
logger.Error("Failed to resolve endpoint: %v", err)
regResult = "failure"

58
util/util.go Normal file
View File

@@ -0,0 +1,58 @@
package util
import (
"fmt"
"net"
"strings"
)
func ResolveDomain(domain string) (string, error) {
// Check if there's a port in the domain
host, port, err := net.SplitHostPort(domain)
if err != nil {
// No port found, use the domain as is
host = domain
port = ""
}
// Remove any protocol prefix if present
if strings.HasPrefix(host, "http://") {
host = strings.TrimPrefix(host, "http://")
} else if strings.HasPrefix(host, "https://") {
host = strings.TrimPrefix(host, "https://")
}
// if there are any trailing slashes, remove them
host = strings.TrimSuffix(host, "/")
// Lookup IP addresses
ips, err := net.LookupIP(host)
if err != nil {
return "", fmt.Errorf("DNS lookup failed: %v", err)
}
if len(ips) == 0 {
return "", fmt.Errorf("no IP addresses found for domain %s", host)
}
// Get the first IPv4 address if available
var ipAddr string
for _, ip := range ips {
if ipv4 := ip.To4(); ipv4 != nil {
ipAddr = ipv4.String()
break
}
}
// If no IPv4 found, use the first IP (might be IPv6)
if ipAddr == "" {
ipAddr = ips[0].String()
}
// Add port back if it existed
if port != "" {
ipAddr = net.JoinHostPort(ipAddr, port)
}
return ipAddr, nil
}

View File

@@ -2,7 +2,6 @@ package wgnetstack
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
@@ -16,14 +15,12 @@ import (
"sync"
"time"
"github.com/fosrl/newt/bind"
"github.com/fosrl/newt/holepunch"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/netstack2"
"github.com/fosrl/newt/network"
"github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/websocket"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/curve25519"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
@@ -66,22 +63,20 @@ type PeerReading struct {
}
type WireGuardService struct {
interfaceName string
mtu int
client *websocket.Client
config WgConfig
key wgtypes.Key
keyFilePath string
newtId string
lastReadings map[string]PeerReading
mu sync.Mutex
Port uint16
stopHolepunch chan struct{}
host string
serverPubKey string
holePunchEndpoint string
token string
stopGetConfig func()
interfaceName string
mtu int
client *websocket.Client
config WgConfig
key wgtypes.Key
keyFilePath string
newtId string
lastReadings map[string]PeerReading
mu sync.Mutex
Port uint16
host string
serverPubKey string
token string
stopGetConfig func()
// Netstack fields
tun tun.Device
tnet *netstack2.Net
@@ -95,6 +90,9 @@ type WireGuardService struct {
// Proxy manager for tunnel
proxyManager *proxy.ProxyManager
TunnelIP string
// Shared bind and holepunch manager
sharedBind *bind.SharedBind
holePunchManager *holepunch.Manager
}
// GetProxyManager returns the proxy manager for this WireGuardService
@@ -118,24 +116,6 @@ func (s *WireGuardService) RemoveProxyTarget(proto, listenIP string, port int) e
return s.proxyManager.RemoveTarget(proto, listenIP, port)
}
// Add this type definition
type fixedPortBind struct {
port uint16
conn.Bind
}
func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
// Ignore the requested port and use our fixed port
return b.Bind.Open(b.port)
}
func NewFixedPortBind(port uint16) conn.Bind {
return &fixedPortBind{
port: port,
Bind: conn.NewDefaultBind(),
}
}
// find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
if maxPort < minPort {
@@ -215,6 +195,28 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
return nil, fmt.Errorf("error finding available port: %v", err)
}
// Create shared UDP socket for both holepunch and WireGuard
localAddr := &net.UDPAddr{
Port: int(port),
IP: net.IPv4zero,
}
udpConn, err := net.ListenUDP("udp", localAddr)
if err != nil {
return nil, fmt.Errorf("failed to create UDP socket: %v", err)
}
sharedBind, err := bind.New(udpConn)
if err != nil {
udpConn.Close()
return nil, fmt.Errorf("failed to create shared bind: %v", err)
}
// Add a reference for the hole punch manager (creator already has one reference for WireGuard)
sharedBind.AddRef()
logger.Info("Created shared UDP socket on port %d (refcount: %d)", port, sharedBind.GetRefCount())
// Parse DNS addresses
dnsAddrs := []netip.Addr{netip.MustParseAddr(dns)}
@@ -227,12 +229,16 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
newtId: newtId,
host: host,
lastReadings: make(map[string]PeerReading),
stopHolepunch: make(chan struct{}),
Port: port,
dns: dnsAddrs,
proxyManager: proxy.NewProxyManagerWithoutTNet(),
sharedBind: sharedBind,
}
// Create the holepunch manager with ResolveDomain function
// We'll need to pass a domain resolver function
service.holePunchManager = holepunch.NewManager(sharedBind, newtId)
// Register websocket handlers
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
@@ -344,10 +350,15 @@ func (s *WireGuardService) Close(rm bool) {
s.stopGetConfig = nil
}
// Stop hole punch manager
if s.holePunchManager != nil {
s.holePunchManager.Stop()
}
s.mu.Lock()
defer s.mu.Unlock()
// Close WireGuard device first - this will automatically close the TUN device
// Close WireGuard device first - this will call sharedBind.Close() which releases WireGuard's reference
if s.device != nil {
s.device.Close()
s.device = nil
@@ -360,28 +371,22 @@ func (s *WireGuardService) Close(rm bool) {
if s.tun != nil {
s.tun = nil // Don't call tun.Close() here since device.Close() already closed it
}
}
func (s *WireGuardService) StartHolepunch(serverPubKey string, endpoint string) {
// if the device is already created dont start a new holepunch
if s.device != nil {
return
// Release the hole punch reference to the shared bind
if s.sharedBind != nil {
// Release hole punch reference (WireGuard already released its reference via device.Close())
logger.Debug("Releasing shared bind (refcount before release: %d)", s.sharedBind.GetRefCount())
s.sharedBind.Release()
s.sharedBind = nil
logger.Info("Released shared UDP bind")
}
s.serverPubKey = serverPubKey
s.holePunchEndpoint = endpoint
logger.Debug("Starting UDP hole punch to %s", s.holePunchEndpoint)
// Create a new stop channel for this holepunch session
s.stopHolepunch = make(chan struct{})
// start the UDP holepunch
go s.keepSendingUDPHolePunch(s.holePunchEndpoint)
}
func (s *WireGuardService) SetToken(token string) {
s.token = token
if s.holePunchManager != nil {
s.holePunchManager.SetToken(token)
}
}
// GetNetstackNet returns the netstack network interface for use by other components
@@ -412,6 +417,19 @@ func (s *WireGuardService) SetOnNetstackClose(callback func()) {
s.onNetstackClose = callback
}
// StartHolepunch starts hole punching to a specific endpoint
func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string) {
if s.holePunchManager == nil {
logger.Warn("Hole punch manager not initialized")
return
}
logger.Info("Starting hole punch to %s with public key: %s", endpoint, publicKey)
if err := s.holePunchManager.StartSingleEndpoint(endpoint, publicKey); err != nil {
logger.Warn("Failed to start hole punch: %v", err)
}
}
func (s *WireGuardService) LoadRemoteConfig() error {
if s.stopGetConfig != nil {
s.stopGetConfig()
@@ -485,10 +503,9 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// Parse the IP address and CIDR mask
tunnelIP := netip.MustParseAddr(parts[0])
// stop the holepunch its a channel
if s.stopHolepunch != nil {
close(s.stopHolepunch)
s.stopHolepunch = nil
// Stop any ongoing hole punch operations
if s.holePunchManager != nil {
s.holePunchManager.Stop()
}
// Parse the IP address from the config
@@ -512,8 +529,8 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
// s.proxyManager.SetTNet(s.tnet)
s.TunnelIP = tunnelIP.String()
// Create WireGuard device
s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger(
// Create WireGuard device using the shared bind
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
device.LogLevelSilent, // Use silent logging by default - could be made configurable
"wireguard: ",
))
@@ -946,171 +963,6 @@ func (s *WireGuardService) reportPeerBandwidth() error {
return nil
}
func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error {
if s.serverPubKey == "" || s.token == "" {
logger.Debug("Server public key or token not set, skipping UDP hole punch")
return nil
}
// Parse server address
serverSplit := strings.Split(serverAddr, ":")
if len(serverSplit) < 2 {
return fmt.Errorf("invalid server address format, expected hostname:port")
}
serverHostname := serverSplit[0]
serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16)
if err != nil {
return fmt.Errorf("failed to parse server port: %v", err)
}
// Resolve server hostname to IP
serverIPAddr := network.HostToAddr(serverHostname)
if serverIPAddr == nil {
return fmt.Errorf("failed to resolve server hostname")
}
// Create local UDP address using the same port as WireGuard
localAddr := &net.UDPAddr{
IP: net.IPv4zero,
Port: int(s.Port),
}
// Create remote server address
remoteAddr := &net.UDPAddr{
IP: serverIPAddr.IP,
Port: int(serverPort),
}
// Create UDP connection bound to the same port as WireGuard
conn, err := net.DialUDP("udp", localAddr, remoteAddr)
if err != nil {
return fmt.Errorf("failed to create netstack UDP connection: %v", err)
}
defer conn.Close()
// Create JSON payload
payload := struct {
NewtID string `json:"newtId"`
Token string `json:"token"`
}{
NewtID: s.newtId,
Token: s.token,
}
// Convert payload to JSON
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %v", err)
}
// Encrypt the payload using the server's WireGuard public key
encryptedPayload, err := s.encryptPayload(payloadBytes)
if err != nil {
return fmt.Errorf("failed to encrypt payload: %v", err)
}
// Convert encrypted payload to JSON
jsonData, err := json.Marshal(encryptedPayload)
if err != nil {
return fmt.Errorf("failed to marshal encrypted payload: %v", err)
}
// Send the encrypted packet using the netstack UDP connection
_, err = conn.Write(jsonData)
if err != nil {
return fmt.Errorf("failed to send UDP packet: %v", err)
}
logger.Debug("Sent UDP hole punch to %s via netstack", remoteAddr.String())
return nil
}
func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) {
// Generate an ephemeral keypair for this message
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
}
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
// Parse the server's public key
serverPubKey, err := wgtypes.ParseKey(s.serverPubKey)
if err != nil {
return nil, fmt.Errorf("failed to parse server public key: %v", err)
}
// Use X25519 for key exchange (replacing deprecated ScalarMult)
var ephPrivKeyFixed [32]byte
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
// Perform X25519 key exchange
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
if err != nil {
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
}
// Create an AEAD cipher using the shared secret
aead, err := chacha20poly1305.New(sharedSecret)
if err != nil {
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
}
// Generate a random nonce
nonce := make([]byte, aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %v", err)
}
// Encrypt the payload
ciphertext := aead.Seal(nil, nonce, payload, nil)
// Prepare the final encrypted message
encryptedMsg := struct {
EphemeralPublicKey string `json:"ephemeralPublicKey"`
Nonce []byte `json:"nonce"`
Ciphertext []byte `json:"ciphertext"`
}{
EphemeralPublicKey: ephemeralPublicKey.String(),
Nonce: nonce,
Ciphertext: ciphertext,
}
return encryptedMsg, nil
}
func (s *WireGuardService) keepSendingUDPHolePunch(host string) {
logger.Info("Starting UDP hole punch routine to %s:21820", host)
// send initial hole punch
if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
logger.Debug("Failed to send initial UDP hole punch: %v", err)
}
ticker := time.NewTicker(3 * time.Second)
defer ticker.Stop()
timeout := time.NewTimer(15 * time.Second)
defer timeout.Stop()
for {
select {
case <-s.stopHolepunch:
logger.Info("Stopping UDP holepunch")
return
case <-timeout.C:
logger.Info("UDP holepunch routine timed out after 15 seconds")
return
case <-ticker.C:
if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
logger.Debug("Failed to send UDP hole punch: %v", err)
}
}
}
}
func (s *WireGuardService) updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
var replace = false
for _, t := range targetData.Targets {
@@ -1242,8 +1094,8 @@ func (s *WireGuardService) ReplaceNetstack() error {
s.tun = newTun
s.tnet = newTnet
// Create new WireGuard device with same port
s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger(
// Create new WireGuard device with same shared bind
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
device.LogLevelSilent,
"wireguard: ",
))