mirror of
https://github.com/fosrl/newt.git
synced 2026-02-19 19:36:39 +00:00
Update to use new packages
This commit is contained in:
378
bind/shared_bind.go
Normal file
378
bind/shared_bind.go
Normal file
@@ -0,0 +1,378 @@
|
||||
//go:build !js
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
// Endpoint represents a network endpoint for the SharedBind
|
||||
type Endpoint struct {
|
||||
AddrPort netip.AddrPort
|
||||
}
|
||||
|
||||
// ClearSrc implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) ClearSrc() {}
|
||||
|
||||
// DstIP implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) DstIP() netip.Addr {
|
||||
return e.AddrPort.Addr()
|
||||
}
|
||||
|
||||
// SrcIP implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) SrcIP() netip.Addr {
|
||||
return netip.Addr{}
|
||||
}
|
||||
|
||||
// DstToBytes implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) DstToBytes() []byte {
|
||||
b, _ := e.AddrPort.MarshalBinary()
|
||||
return b
|
||||
}
|
||||
|
||||
// DstToString implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) DstToString() string {
|
||||
return e.AddrPort.String()
|
||||
}
|
||||
|
||||
// SrcToString implements the wgConn.Endpoint interface
|
||||
func (e *Endpoint) SrcToString() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// SharedBind is a thread-safe UDP bind that can be shared between WireGuard
|
||||
// and hole punch senders. It wraps a single UDP connection and implements
|
||||
// reference counting to prevent premature closure.
|
||||
type SharedBind struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// The underlying UDP connection
|
||||
udpConn *net.UDPConn
|
||||
|
||||
// IPv4 and IPv6 packet connections for advanced features
|
||||
ipv4PC *ipv4.PacketConn
|
||||
ipv6PC *ipv6.PacketConn
|
||||
|
||||
// Reference counting to prevent closing while in use
|
||||
refCount atomic.Int32
|
||||
closed atomic.Bool
|
||||
|
||||
// Channels for receiving data
|
||||
recvFuncs []wgConn.ReceiveFunc
|
||||
|
||||
// Port binding information
|
||||
port uint16
|
||||
}
|
||||
|
||||
// New creates a new SharedBind from an existing UDP connection.
|
||||
// The SharedBind takes ownership of the connection and will close it
|
||||
// when all references are released.
|
||||
func New(udpConn *net.UDPConn) (*SharedBind, error) {
|
||||
if udpConn == nil {
|
||||
return nil, fmt.Errorf("udpConn cannot be nil")
|
||||
}
|
||||
|
||||
bind := &SharedBind{
|
||||
udpConn: udpConn,
|
||||
}
|
||||
|
||||
// Initialize reference count to 1 (the creator holds the first reference)
|
||||
bind.refCount.Store(1)
|
||||
|
||||
// Get the local port
|
||||
if addr, ok := udpConn.LocalAddr().(*net.UDPAddr); ok {
|
||||
bind.port = uint16(addr.Port)
|
||||
}
|
||||
|
||||
return bind, nil
|
||||
}
|
||||
|
||||
// AddRef increments the reference count. Call this when sharing
|
||||
// the bind with another component.
|
||||
func (b *SharedBind) AddRef() {
|
||||
newCount := b.refCount.Add(1)
|
||||
// Optional: Add logging for debugging
|
||||
_ = newCount // Placeholder for potential logging
|
||||
}
|
||||
|
||||
// Release decrements the reference count. When it reaches zero,
|
||||
// the underlying UDP connection is closed.
|
||||
func (b *SharedBind) Release() error {
|
||||
newCount := b.refCount.Add(-1)
|
||||
// Optional: Add logging for debugging
|
||||
_ = newCount // Placeholder for potential logging
|
||||
|
||||
if newCount < 0 {
|
||||
// This should never happen with proper usage
|
||||
b.refCount.Store(0)
|
||||
return fmt.Errorf("SharedBind reference count went negative")
|
||||
}
|
||||
|
||||
if newCount == 0 {
|
||||
return b.closeConnection()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// closeConnection actually closes the UDP connection
|
||||
func (b *SharedBind) closeConnection() error {
|
||||
if !b.closed.CompareAndSwap(false, true) {
|
||||
// Already closed
|
||||
return nil
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
var err error
|
||||
if b.udpConn != nil {
|
||||
err = b.udpConn.Close()
|
||||
b.udpConn = nil
|
||||
}
|
||||
|
||||
b.ipv4PC = nil
|
||||
b.ipv6PC = nil
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// GetUDPConn returns the underlying UDP connection.
|
||||
// The caller must not close this connection directly.
|
||||
func (b *SharedBind) GetUDPConn() *net.UDPConn {
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
return b.udpConn
|
||||
}
|
||||
|
||||
// GetRefCount returns the current reference count (for debugging)
|
||||
func (b *SharedBind) GetRefCount() int32 {
|
||||
return b.refCount.Load()
|
||||
}
|
||||
|
||||
// IsClosed returns whether the bind is closed
|
||||
func (b *SharedBind) IsClosed() bool {
|
||||
return b.closed.Load()
|
||||
}
|
||||
|
||||
// WriteToUDP writes data to a specific UDP address.
|
||||
// This is thread-safe and can be used by hole punch senders.
|
||||
func (b *SharedBind) WriteToUDP(data []byte, addr *net.UDPAddr) (int, error) {
|
||||
if b.closed.Load() {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
conn := b.udpConn
|
||||
b.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
return conn.WriteToUDP(data, addr)
|
||||
}
|
||||
|
||||
// Close implements the WireGuard Bind interface.
|
||||
// It decrements the reference count and closes the connection if no references remain.
|
||||
func (b *SharedBind) Close() error {
|
||||
return b.Release()
|
||||
}
|
||||
|
||||
// Open implements the WireGuard Bind interface.
|
||||
// Since the connection is already open, this just sets up the receive functions.
|
||||
func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||
if b.closed.Load() {
|
||||
return nil, 0, net.ErrClosed
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.udpConn == nil {
|
||||
return nil, 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// Set up IPv4 and IPv6 packet connections for advanced features
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
b.ipv4PC = ipv4.NewPacketConn(b.udpConn)
|
||||
b.ipv6PC = ipv6.NewPacketConn(b.udpConn)
|
||||
}
|
||||
|
||||
// Create receive functions
|
||||
recvFuncs := make([]wgConn.ReceiveFunc, 0, 2)
|
||||
|
||||
// Add IPv4 receive function
|
||||
if b.ipv4PC != nil || runtime.GOOS != "linux" {
|
||||
recvFuncs = append(recvFuncs, b.makeReceiveIPv4())
|
||||
}
|
||||
|
||||
// Add IPv6 receive function if needed
|
||||
// For now, we focus on IPv4 for hole punching use case
|
||||
|
||||
b.recvFuncs = recvFuncs
|
||||
return recvFuncs, b.port, nil
|
||||
}
|
||||
|
||||
// makeReceiveIPv4 creates a receive function for IPv4 packets
|
||||
func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
if b.closed.Load() {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
conn := b.udpConn
|
||||
pc := b.ipv4PC
|
||||
b.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// Use batch reading on Linux for performance
|
||||
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
|
||||
return b.receiveIPv4Batch(pc, bufs, sizes, eps)
|
||||
}
|
||||
|
||||
// Fallback to simple read for other platforms
|
||||
return b.receiveIPv4Simple(conn, bufs, sizes, eps)
|
||||
}
|
||||
}
|
||||
|
||||
// receiveIPv4Batch uses batch reading for better performance on Linux
|
||||
func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||
// Create messages for batch reading
|
||||
msgs := make([]ipv4.Message, len(bufs))
|
||||
for i := range bufs {
|
||||
msgs[i].Buffers = [][]byte{bufs[i]}
|
||||
msgs[i].OOB = make([]byte, 0) // No OOB data needed for basic use
|
||||
}
|
||||
|
||||
numMsgs, err := pc.ReadBatch(msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
sizes[i] = msgs[i].N
|
||||
if sizes[i] == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if msgs[i].Addr != nil {
|
||||
if udpAddr, ok := msgs[i].Addr.(*net.UDPAddr); ok {
|
||||
addrPort := udpAddr.AddrPort()
|
||||
eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return numMsgs, nil
|
||||
}
|
||||
|
||||
// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms
|
||||
func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||
n, addr, err := conn.ReadFromUDP(bufs[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
sizes[0] = n
|
||||
if addr != nil {
|
||||
addrPort := addr.AddrPort()
|
||||
eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||
}
|
||||
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
// Send implements the WireGuard Bind interface.
|
||||
// It sends packets to the specified endpoint.
|
||||
func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
if b.closed.Load() {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
conn := b.udpConn
|
||||
b.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
// Extract the destination address from the endpoint
|
||||
var destAddr *net.UDPAddr
|
||||
|
||||
// Try to cast to StdNetEndpoint first
|
||||
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
|
||||
destAddr = net.UDPAddrFromAddrPort(stdEp.AddrPort)
|
||||
} else {
|
||||
// Fallback: construct from DstIP and DstToBytes
|
||||
dstBytes := ep.DstToBytes()
|
||||
if len(dstBytes) >= 6 { // Minimum for IPv4 (4 bytes) + port (2 bytes)
|
||||
var addr netip.Addr
|
||||
var port uint16
|
||||
|
||||
if len(dstBytes) >= 18 { // IPv6 (16 bytes) + port (2 bytes)
|
||||
addr, _ = netip.AddrFromSlice(dstBytes[:16])
|
||||
port = uint16(dstBytes[16]) | uint16(dstBytes[17])<<8
|
||||
} else { // IPv4
|
||||
addr, _ = netip.AddrFromSlice(dstBytes[:4])
|
||||
port = uint16(dstBytes[4]) | uint16(dstBytes[5])<<8
|
||||
}
|
||||
|
||||
if addr.IsValid() {
|
||||
destAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if destAddr == nil {
|
||||
return fmt.Errorf("could not extract destination address from endpoint")
|
||||
}
|
||||
|
||||
// Send all buffers to the destination
|
||||
for _, buf := range bufs {
|
||||
_, err := conn.WriteToUDP(buf, destAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMark implements the WireGuard Bind interface.
|
||||
// It's a no-op for this implementation.
|
||||
func (b *SharedBind) SetMark(mark uint32) error {
|
||||
// Not implemented for this use case
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchSize returns the preferred batch size for sending packets.
|
||||
func (b *SharedBind) BatchSize() int {
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
return wgConn.IdealBatchSize
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// ParseEndpoint creates a new endpoint from a string address.
|
||||
func (b *SharedBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
|
||||
addrPort, err := netip.ParseAddrPort(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &wgConn.StdNetEndpoint{AddrPort: addrPort}, nil
|
||||
}
|
||||
424
bind/shared_bind_test.go
Normal file
424
bind/shared_bind_test.go
Normal file
@@ -0,0 +1,424 @@
|
||||
//go:build !js
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
// TestSharedBindCreation tests basic creation and initialization
|
||||
func TestSharedBindCreation(t *testing.T) {
|
||||
// Create a UDP connection
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||
}
|
||||
defer udpConn.Close()
|
||||
|
||||
// Create SharedBind
|
||||
bind, err := New(udpConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
|
||||
if bind == nil {
|
||||
t.Fatal("SharedBind is nil")
|
||||
}
|
||||
|
||||
// Verify initial reference count
|
||||
if bind.refCount.Load() != 1 {
|
||||
t.Errorf("Expected initial refCount to be 1, got %d", bind.refCount.Load())
|
||||
}
|
||||
|
||||
// Clean up
|
||||
if err := bind.Close(); err != nil {
|
||||
t.Errorf("Failed to close SharedBind: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSharedBindReferenceCount tests reference counting
|
||||
func TestSharedBindReferenceCount(t *testing.T) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||
}
|
||||
|
||||
bind, err := New(udpConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
|
||||
// Add references
|
||||
bind.AddRef()
|
||||
if bind.refCount.Load() != 2 {
|
||||
t.Errorf("Expected refCount to be 2, got %d", bind.refCount.Load())
|
||||
}
|
||||
|
||||
bind.AddRef()
|
||||
if bind.refCount.Load() != 3 {
|
||||
t.Errorf("Expected refCount to be 3, got %d", bind.refCount.Load())
|
||||
}
|
||||
|
||||
// Release references
|
||||
bind.Release()
|
||||
if bind.refCount.Load() != 2 {
|
||||
t.Errorf("Expected refCount to be 2 after release, got %d", bind.refCount.Load())
|
||||
}
|
||||
|
||||
bind.Release()
|
||||
bind.Release() // This should close the connection
|
||||
|
||||
if !bind.closed.Load() {
|
||||
t.Error("Expected bind to be closed after all references released")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSharedBindWriteToUDP tests the WriteToUDP functionality
|
||||
func TestSharedBindWriteToUDP(t *testing.T) {
|
||||
// Create sender
|
||||
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||
}
|
||||
|
||||
senderBind, err := New(senderConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||
}
|
||||
defer senderBind.Close()
|
||||
|
||||
// Create receiver
|
||||
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||
}
|
||||
defer receiverConn.Close()
|
||||
|
||||
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
// Send data
|
||||
testData := []byte("Hello, SharedBind!")
|
||||
n, err := senderBind.WriteToUDP(testData, receiverAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteToUDP failed: %v", err)
|
||||
}
|
||||
|
||||
if n != len(testData) {
|
||||
t.Errorf("Expected to send %d bytes, sent %d", len(testData), n)
|
||||
}
|
||||
|
||||
// Receive data
|
||||
buf := make([]byte, 1024)
|
||||
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, _, err = receiverConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to receive data: %v", err)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(testData) {
|
||||
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSharedBindConcurrentWrites tests thread-safety
|
||||
func TestSharedBindConcurrentWrites(t *testing.T) {
|
||||
// Create sender
|
||||
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||
}
|
||||
|
||||
senderBind, err := New(senderConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||
}
|
||||
defer senderBind.Close()
|
||||
|
||||
// Create receiver
|
||||
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||
}
|
||||
defer receiverConn.Close()
|
||||
|
||||
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
// Launch concurrent writes
|
||||
numGoroutines := 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
data := []byte{byte(id)}
|
||||
_, err := senderBind.WriteToUDP(data, receiverAddr)
|
||||
if err != nil {
|
||||
t.Errorf("WriteToUDP failed in goroutine %d: %v", id, err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation
|
||||
func TestSharedBindWireGuardInterface(t *testing.T) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||
}
|
||||
|
||||
bind, err := New(udpConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
defer bind.Close()
|
||||
|
||||
// Test Open
|
||||
recvFuncs, port, err := bind.Open(0)
|
||||
if err != nil {
|
||||
t.Fatalf("Open failed: %v", err)
|
||||
}
|
||||
|
||||
if len(recvFuncs) == 0 {
|
||||
t.Error("Expected at least one receive function")
|
||||
}
|
||||
|
||||
if port == 0 {
|
||||
t.Error("Expected non-zero port")
|
||||
}
|
||||
|
||||
// Test SetMark (should be a no-op)
|
||||
if err := bind.SetMark(0); err != nil {
|
||||
t.Errorf("SetMark failed: %v", err)
|
||||
}
|
||||
|
||||
// Test BatchSize
|
||||
batchSize := bind.BatchSize()
|
||||
if batchSize <= 0 {
|
||||
t.Error("Expected positive batch size")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSharedBindSend tests the Send method with WireGuard endpoints
|
||||
func TestSharedBindSend(t *testing.T) {
|
||||
// Create sender
|
||||
senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender UDP connection: %v", err)
|
||||
}
|
||||
|
||||
senderBind, err := New(senderConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create sender SharedBind: %v", err)
|
||||
}
|
||||
defer senderBind.Close()
|
||||
|
||||
// Create receiver
|
||||
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||
}
|
||||
defer receiverConn.Close()
|
||||
|
||||
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
// Create an endpoint
|
||||
addrPort := receiverAddr.AddrPort()
|
||||
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||
|
||||
// Send data
|
||||
testData := []byte("WireGuard packet")
|
||||
bufs := [][]byte{testData}
|
||||
err = senderBind.Send(bufs, endpoint)
|
||||
if err != nil {
|
||||
t.Fatalf("Send failed: %v", err)
|
||||
}
|
||||
|
||||
// Receive data
|
||||
buf := make([]byte, 1024)
|
||||
receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, _, err := receiverConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to receive data: %v", err)
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(testData) {
|
||||
t.Errorf("Expected to receive %q, got %q", testData, buf[:n])
|
||||
}
|
||||
}
|
||||
|
||||
// TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind
|
||||
func TestSharedBindMultipleUsers(t *testing.T) {
|
||||
// Create shared bind
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||
}
|
||||
|
||||
sharedBind, err := New(udpConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
|
||||
// Add reference for hole punch sender
|
||||
sharedBind.AddRef()
|
||||
|
||||
// Create receiver
|
||||
receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create receiver UDP connection: %v", err)
|
||||
}
|
||||
defer receiverConn.Close()
|
||||
|
||||
receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Simulate WireGuard using the bind
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
addrPort := receiverAddr.AddrPort()
|
||||
endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
data := []byte("WireGuard packet")
|
||||
bufs := [][]byte{data}
|
||||
if err := sharedBind.Send(bufs, endpoint); err != nil {
|
||||
t.Errorf("WireGuard Send failed: %v", err)
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
// Simulate hole punch sender using the bind
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 10; i++ {
|
||||
data := []byte("Hole punch packet")
|
||||
if _, err := sharedBind.WriteToUDP(data, receiverAddr); err != nil {
|
||||
t.Errorf("Hole punch WriteToUDP failed: %v", err)
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Release the hole punch reference
|
||||
sharedBind.Release()
|
||||
|
||||
// Close WireGuard's reference (should close the connection)
|
||||
sharedBind.Close()
|
||||
|
||||
if !sharedBind.closed.Load() {
|
||||
t.Error("Expected bind to be closed after all users released it")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEndpoint tests the Endpoint implementation
|
||||
func TestEndpoint(t *testing.T) {
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
addrPort := netip.AddrPortFrom(addr, 51820)
|
||||
|
||||
ep := &Endpoint{AddrPort: addrPort}
|
||||
|
||||
// Test DstIP
|
||||
if ep.DstIP() != addr {
|
||||
t.Errorf("Expected DstIP to be %v, got %v", addr, ep.DstIP())
|
||||
}
|
||||
|
||||
// Test DstToString
|
||||
expected := "192.168.1.1:51820"
|
||||
if ep.DstToString() != expected {
|
||||
t.Errorf("Expected DstToString to be %q, got %q", expected, ep.DstToString())
|
||||
}
|
||||
|
||||
// Test DstToBytes
|
||||
bytes := ep.DstToBytes()
|
||||
if len(bytes) == 0 {
|
||||
t.Error("Expected DstToBytes to return non-empty slice")
|
||||
}
|
||||
|
||||
// Test SrcIP (should be zero)
|
||||
if ep.SrcIP().IsValid() {
|
||||
t.Error("Expected SrcIP to be invalid")
|
||||
}
|
||||
|
||||
// Test ClearSrc (should not panic)
|
||||
ep.ClearSrc()
|
||||
}
|
||||
|
||||
// TestParseEndpoint tests the ParseEndpoint method
|
||||
func TestParseEndpoint(t *testing.T) {
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create UDP connection: %v", err)
|
||||
}
|
||||
|
||||
bind, err := New(udpConn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create SharedBind: %v", err)
|
||||
}
|
||||
defer bind.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
checkAddr func(*testing.T, wgConn.Endpoint)
|
||||
}{
|
||||
{
|
||||
name: "valid IPv4",
|
||||
input: "192.168.1.1:51820",
|
||||
wantErr: false,
|
||||
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
|
||||
if ep.DstToString() != "192.168.1.1:51820" {
|
||||
t.Errorf("Expected 192.168.1.1:51820, got %s", ep.DstToString())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid IPv6",
|
||||
input: "[::1]:51820",
|
||||
wantErr: false,
|
||||
checkAddr: func(t *testing.T, ep wgConn.Endpoint) {
|
||||
if ep.DstToString() != "[::1]:51820" {
|
||||
t.Errorf("Expected [::1]:51820, got %s", ep.DstToString())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid - missing port",
|
||||
input: "192.168.1.1",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid - bad format",
|
||||
input: "not-an-address",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ep, err := bind.ParseEndpoint(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseEndpoint() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && tt.checkAddr != nil {
|
||||
tt.checkAddr(t, ep)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
@@ -398,57 +397,6 @@ func mapToWireGuardLogLevel(level logger.LogLevel) int {
|
||||
}
|
||||
}
|
||||
|
||||
func resolveDomain(domain string) (string, error) {
|
||||
// Check if there's a port in the domain
|
||||
host, port, err := net.SplitHostPort(domain)
|
||||
if err != nil {
|
||||
// No port found, use the domain as is
|
||||
host = domain
|
||||
port = ""
|
||||
}
|
||||
|
||||
// Remove any protocol prefix if present
|
||||
if strings.HasPrefix(host, "http://") {
|
||||
host = strings.TrimPrefix(host, "http://")
|
||||
} else if strings.HasPrefix(host, "https://") {
|
||||
host = strings.TrimPrefix(host, "https://")
|
||||
}
|
||||
|
||||
// if there are any trailing slashes, remove them
|
||||
host = strings.TrimSuffix(host, "/")
|
||||
|
||||
// Lookup IP addresses
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("DNS lookup failed: %v", err)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return "", fmt.Errorf("no IP addresses found for domain %s", host)
|
||||
}
|
||||
|
||||
// Get the first IPv4 address if available
|
||||
var ipAddr string
|
||||
for _, ip := range ips {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
ipAddr = ipv4.String()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If no IPv4 found, use the first IP (might be IPv6)
|
||||
if ipAddr == "" {
|
||||
ipAddr = ips[0].String()
|
||||
}
|
||||
|
||||
// Add port back if it existed
|
||||
if port != "" {
|
||||
ipAddr = net.JoinHostPort(ipAddr, port)
|
||||
}
|
||||
|
||||
return ipAddr, nil
|
||||
}
|
||||
|
||||
func parseTargetData(data interface{}) (TargetData, error) {
|
||||
var targetData TargetData
|
||||
jsonData, err := json.Marshal(data)
|
||||
16
go.mod
16
go.mod
@@ -17,9 +17,9 @@ require (
|
||||
go.opentelemetry.io/otel/metric v1.38.0
|
||||
go.opentelemetry.io/otel/sdk v1.38.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792
|
||||
golang.org/x/net v0.46.0
|
||||
golang.org/x/crypto v0.44.0
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6
|
||||
golang.org/x/net v0.47.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||
google.golang.org/grpc v1.76.0
|
||||
@@ -69,12 +69,12 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.38.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||
golang.org/x/mod v0.28.0 // indirect
|
||||
golang.org/x/sync v0.17.0 // indirect
|
||||
golang.org/x/sys v0.37.0 // indirect
|
||||
golang.org/x/text v0.30.0 // indirect
|
||||
golang.org/x/mod v0.30.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
golang.org/x/time v0.12.0 // indirect
|
||||
golang.org/x/tools v0.37.0 // indirect
|
||||
golang.org/x/tools v0.39.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||
|
||||
15
go.sum
15
go.sum
@@ -107,32 +107,47 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
|
||||
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
|
||||
golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/mod v0.28.0 h1:gQBtGhjxykdjY9YhZpSlZIsbnaE2+PgjfLWUQTnoZ1U=
|
||||
golang.org/x/mod v0.28.0/go.mod h1:yfB/L0NOf/kmEbXjzCPOx1iK1fRutOydrCMsqRhEBxI=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/tools v0.37.0 h1:DVSRzp7FwePZW356yEAChSdNcQo6Nsp+fex1SUW09lE=
|
||||
golang.org/x/tools v0.37.0/go.mod h1:MBN5QPQtLMHVdvsbtarmTNukZDdgwdwlO5qGacAzF0w=
|
||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||
|
||||
347
holepunch/holepunch.go
Normal file
347
holepunch/holepunch.go
Normal file
@@ -0,0 +1,347 @@
|
||||
package holepunch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/bind"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/exp/rand"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// ExitNode represents a WireGuard exit node for hole punching
|
||||
type ExitNode struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
}
|
||||
|
||||
// Manager handles UDP hole punching operations
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
stopChan chan struct{}
|
||||
sharedBind *bind.SharedBind
|
||||
newtID string
|
||||
token string
|
||||
}
|
||||
|
||||
// NewManager creates a new hole punch manager
|
||||
func NewManager(sharedBind *bind.SharedBind, newtID string) *Manager {
|
||||
return &Manager{
|
||||
sharedBind: sharedBind,
|
||||
newtID: newtID,
|
||||
}
|
||||
}
|
||||
|
||||
// SetToken updates the authentication token used for hole punching
|
||||
func (m *Manager) SetToken(token string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.token = token
|
||||
}
|
||||
|
||||
// IsRunning returns whether hole punching is currently active
|
||||
func (m *Manager) IsRunning() bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.running
|
||||
}
|
||||
|
||||
// Stop stops any ongoing hole punch operations
|
||||
func (m *Manager) Stop() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !m.running {
|
||||
return
|
||||
}
|
||||
|
||||
if m.stopChan != nil {
|
||||
close(m.stopChan)
|
||||
m.stopChan = nil
|
||||
}
|
||||
|
||||
m.running = false
|
||||
logger.Info("Hole punch manager stopped")
|
||||
}
|
||||
|
||||
// StartMultipleExitNodes starts hole punching to multiple exit nodes
|
||||
func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error {
|
||||
m.mu.Lock()
|
||||
|
||||
if m.running {
|
||||
m.mu.Unlock()
|
||||
logger.Debug("UDP hole punch already running, skipping new request")
|
||||
return fmt.Errorf("hole punch already running")
|
||||
}
|
||||
|
||||
if len(exitNodes) == 0 {
|
||||
m.mu.Unlock()
|
||||
logger.Warn("No exit nodes provided for hole punching")
|
||||
return fmt.Errorf("no exit nodes provided")
|
||||
}
|
||||
|
||||
m.running = true
|
||||
m.stopChan = make(chan struct{})
|
||||
m.mu.Unlock()
|
||||
|
||||
logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes))
|
||||
|
||||
go m.runMultipleExitNodes(exitNodes)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartSingleEndpoint starts hole punching to a single endpoint (legacy mode)
|
||||
func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error {
|
||||
m.mu.Lock()
|
||||
|
||||
if m.running {
|
||||
m.mu.Unlock()
|
||||
logger.Debug("UDP hole punch already running, skipping new request")
|
||||
return fmt.Errorf("hole punch already running")
|
||||
}
|
||||
|
||||
m.running = true
|
||||
m.stopChan = make(chan struct{})
|
||||
m.mu.Unlock()
|
||||
|
||||
logger.Info("Starting UDP hole punch to %s with shared bind", endpoint)
|
||||
|
||||
go m.runSingleEndpoint(endpoint, serverPubKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// runMultipleExitNodes performs hole punching to multiple exit nodes
|
||||
func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) {
|
||||
defer func() {
|
||||
m.mu.Lock()
|
||||
m.running = false
|
||||
m.mu.Unlock()
|
||||
logger.Info("UDP hole punch goroutine ended for all exit nodes")
|
||||
}()
|
||||
|
||||
// Resolve all endpoints upfront
|
||||
type resolvedExitNode struct {
|
||||
remoteAddr *net.UDPAddr
|
||||
publicKey string
|
||||
endpointName string
|
||||
}
|
||||
|
||||
var resolvedNodes []resolvedExitNode
|
||||
for _, exitNode := range exitNodes {
|
||||
host, err := util.ResolveDomain(exitNode.Endpoint)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err)
|
||||
continue
|
||||
}
|
||||
|
||||
serverAddr := net.JoinHostPort(host, "21820")
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
resolvedNodes = append(resolvedNodes, resolvedExitNode{
|
||||
remoteAddr: remoteAddr,
|
||||
publicKey: exitNode.PublicKey,
|
||||
endpointName: exitNode.Endpoint,
|
||||
})
|
||||
logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String())
|
||||
}
|
||||
|
||||
if len(resolvedNodes) == 0 {
|
||||
logger.Error("No exit nodes could be resolved")
|
||||
return
|
||||
}
|
||||
|
||||
// Send initial hole punch to all exit nodes
|
||||
for _, node := range resolvedNodes {
|
||||
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
|
||||
logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err)
|
||||
}
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.NewTimer(15 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.stopChan:
|
||||
logger.Debug("Hole punch stopped by signal")
|
||||
return
|
||||
case <-timeout.C:
|
||||
logger.Debug("Hole punch timeout reached")
|
||||
return
|
||||
case <-ticker.C:
|
||||
// Send hole punch to all exit nodes
|
||||
for _, node := range resolvedNodes {
|
||||
if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil {
|
||||
logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runSingleEndpoint performs hole punching to a single endpoint
|
||||
func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) {
|
||||
defer func() {
|
||||
m.mu.Lock()
|
||||
m.running = false
|
||||
m.mu.Unlock()
|
||||
logger.Info("UDP hole punch goroutine ended for %s", endpoint)
|
||||
}()
|
||||
|
||||
host, err := util.ResolveDomain(endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve domain %s: %v", endpoint, err)
|
||||
return
|
||||
}
|
||||
|
||||
serverAddr := net.JoinHostPort(host, "21820")
|
||||
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Execute once immediately before starting the loop
|
||||
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
|
||||
logger.Warn("Failed to send initial hole punch: %v", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.NewTimer(15 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.stopChan:
|
||||
logger.Debug("Hole punch stopped by signal")
|
||||
return
|
||||
case <-timeout.C:
|
||||
logger.Debug("Hole punch timeout reached")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil {
|
||||
logger.Debug("Failed to send hole punch: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendHolePunch sends an encrypted hole punch packet using the shared bind
|
||||
func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error {
|
||||
m.mu.Lock()
|
||||
token := m.token
|
||||
newtID := m.newtID
|
||||
m.mu.Unlock()
|
||||
|
||||
if serverPubKey == "" || token == "" {
|
||||
return fmt.Errorf("server public key or OLM token is empty")
|
||||
}
|
||||
|
||||
payload := struct {
|
||||
NewtID string `json:"newtId"`
|
||||
Token string `json:"token"`
|
||||
}{
|
||||
NewtID: newtID,
|
||||
Token: token,
|
||||
}
|
||||
|
||||
// Convert payload to JSON
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload using the server's WireGuard public key
|
||||
encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt payload: %w", err)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(encryptedPayload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal encrypted payload: %w", err)
|
||||
}
|
||||
|
||||
_, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write to UDP: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange
|
||||
func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) {
|
||||
// Generate an ephemeral keypair for this message
|
||||
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
|
||||
}
|
||||
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
|
||||
|
||||
// Parse the server's public key
|
||||
serverPubKey, err := wgtypes.ParseKey(serverPublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse server public key: %v", err)
|
||||
}
|
||||
|
||||
// Use X25519 for key exchange
|
||||
var ephPrivKeyFixed [32]byte
|
||||
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
|
||||
|
||||
// Perform X25519 key exchange
|
||||
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
||||
}
|
||||
|
||||
// Create an AEAD cipher using the shared secret
|
||||
aead, err := chacha20poly1305.New(sharedSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload
|
||||
ciphertext := aead.Seal(nil, nonce, payload, nil)
|
||||
|
||||
// Prepare the final encrypted message
|
||||
encryptedMsg := struct {
|
||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
}{
|
||||
EphemeralPublicKey: ephemeralPublicKey.String(),
|
||||
Nonce: nonce,
|
||||
Ciphertext: ciphertext,
|
||||
}
|
||||
|
||||
return encryptedMsg, nil
|
||||
}
|
||||
3
main.go
3
main.go
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/proxy"
|
||||
"github.com/fosrl/newt/updates"
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
|
||||
"github.com/fosrl/newt/internal/state"
|
||||
@@ -663,7 +664,7 @@ func main() {
|
||||
|
||||
logger.Info("Connecting to endpoint: %s", host)
|
||||
|
||||
endpoint, err := resolveDomain(wgData.Endpoint)
|
||||
endpoint, err := util.ResolveDomain(wgData.Endpoint)
|
||||
if err != nil {
|
||||
logger.Error("Failed to resolve endpoint: %v", err)
|
||||
regResult = "failure"
|
||||
|
||||
58
util/util.go
Normal file
58
util/util.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func ResolveDomain(domain string) (string, error) {
|
||||
// Check if there's a port in the domain
|
||||
host, port, err := net.SplitHostPort(domain)
|
||||
if err != nil {
|
||||
// No port found, use the domain as is
|
||||
host = domain
|
||||
port = ""
|
||||
}
|
||||
|
||||
// Remove any protocol prefix if present
|
||||
if strings.HasPrefix(host, "http://") {
|
||||
host = strings.TrimPrefix(host, "http://")
|
||||
} else if strings.HasPrefix(host, "https://") {
|
||||
host = strings.TrimPrefix(host, "https://")
|
||||
}
|
||||
|
||||
// if there are any trailing slashes, remove them
|
||||
host = strings.TrimSuffix(host, "/")
|
||||
|
||||
// Lookup IP addresses
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("DNS lookup failed: %v", err)
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return "", fmt.Errorf("no IP addresses found for domain %s", host)
|
||||
}
|
||||
|
||||
// Get the first IPv4 address if available
|
||||
var ipAddr string
|
||||
for _, ip := range ips {
|
||||
if ipv4 := ip.To4(); ipv4 != nil {
|
||||
ipAddr = ipv4.String()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If no IPv4 found, use the first IP (might be IPv6)
|
||||
if ipAddr == "" {
|
||||
ipAddr = ips[0].String()
|
||||
}
|
||||
|
||||
// Add port back if it existed
|
||||
if port != "" {
|
||||
ipAddr = net.JoinHostPort(ipAddr, port)
|
||||
}
|
||||
|
||||
return ipAddr, nil
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package wgnetstack
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
@@ -16,14 +15,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/bind"
|
||||
"github.com/fosrl/newt/holepunch"
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/netstack2"
|
||||
"github.com/fosrl/newt/network"
|
||||
"github.com/fosrl/newt/proxy"
|
||||
"github.com/fosrl/newt/websocket"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
@@ -66,22 +63,20 @@ type PeerReading struct {
|
||||
}
|
||||
|
||||
type WireGuardService struct {
|
||||
interfaceName string
|
||||
mtu int
|
||||
client *websocket.Client
|
||||
config WgConfig
|
||||
key wgtypes.Key
|
||||
keyFilePath string
|
||||
newtId string
|
||||
lastReadings map[string]PeerReading
|
||||
mu sync.Mutex
|
||||
Port uint16
|
||||
stopHolepunch chan struct{}
|
||||
host string
|
||||
serverPubKey string
|
||||
holePunchEndpoint string
|
||||
token string
|
||||
stopGetConfig func()
|
||||
interfaceName string
|
||||
mtu int
|
||||
client *websocket.Client
|
||||
config WgConfig
|
||||
key wgtypes.Key
|
||||
keyFilePath string
|
||||
newtId string
|
||||
lastReadings map[string]PeerReading
|
||||
mu sync.Mutex
|
||||
Port uint16
|
||||
host string
|
||||
serverPubKey string
|
||||
token string
|
||||
stopGetConfig func()
|
||||
// Netstack fields
|
||||
tun tun.Device
|
||||
tnet *netstack2.Net
|
||||
@@ -95,6 +90,9 @@ type WireGuardService struct {
|
||||
// Proxy manager for tunnel
|
||||
proxyManager *proxy.ProxyManager
|
||||
TunnelIP string
|
||||
// Shared bind and holepunch manager
|
||||
sharedBind *bind.SharedBind
|
||||
holePunchManager *holepunch.Manager
|
||||
}
|
||||
|
||||
// GetProxyManager returns the proxy manager for this WireGuardService
|
||||
@@ -118,24 +116,6 @@ func (s *WireGuardService) RemoveProxyTarget(proto, listenIP string, port int) e
|
||||
return s.proxyManager.RemoveTarget(proto, listenIP, port)
|
||||
}
|
||||
|
||||
// Add this type definition
|
||||
type fixedPortBind struct {
|
||||
port uint16
|
||||
conn.Bind
|
||||
}
|
||||
|
||||
func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
// Ignore the requested port and use our fixed port
|
||||
return b.Bind.Open(b.port)
|
||||
}
|
||||
|
||||
func NewFixedPortBind(port uint16) conn.Bind {
|
||||
return &fixedPortBind{
|
||||
port: port,
|
||||
Bind: conn.NewDefaultBind(),
|
||||
}
|
||||
}
|
||||
|
||||
// find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester
|
||||
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||
if maxPort < minPort {
|
||||
@@ -215,6 +195,28 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
|
||||
return nil, fmt.Errorf("error finding available port: %v", err)
|
||||
}
|
||||
|
||||
// Create shared UDP socket for both holepunch and WireGuard
|
||||
localAddr := &net.UDPAddr{
|
||||
Port: int(port),
|
||||
IP: net.IPv4zero,
|
||||
}
|
||||
|
||||
udpConn, err := net.ListenUDP("udp", localAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create UDP socket: %v", err)
|
||||
}
|
||||
|
||||
sharedBind, err := bind.New(udpConn)
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
return nil, fmt.Errorf("failed to create shared bind: %v", err)
|
||||
}
|
||||
|
||||
// Add a reference for the hole punch manager (creator already has one reference for WireGuard)
|
||||
sharedBind.AddRef()
|
||||
|
||||
logger.Info("Created shared UDP socket on port %d (refcount: %d)", port, sharedBind.GetRefCount())
|
||||
|
||||
// Parse DNS addresses
|
||||
dnsAddrs := []netip.Addr{netip.MustParseAddr(dns)}
|
||||
|
||||
@@ -227,12 +229,16 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
|
||||
newtId: newtId,
|
||||
host: host,
|
||||
lastReadings: make(map[string]PeerReading),
|
||||
stopHolepunch: make(chan struct{}),
|
||||
Port: port,
|
||||
dns: dnsAddrs,
|
||||
proxyManager: proxy.NewProxyManagerWithoutTNet(),
|
||||
sharedBind: sharedBind,
|
||||
}
|
||||
|
||||
// Create the holepunch manager with ResolveDomain function
|
||||
// We'll need to pass a domain resolver function
|
||||
service.holePunchManager = holepunch.NewManager(sharedBind, newtId)
|
||||
|
||||
// Register websocket handlers
|
||||
wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig)
|
||||
wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer)
|
||||
@@ -344,10 +350,15 @@ func (s *WireGuardService) Close(rm bool) {
|
||||
s.stopGetConfig = nil
|
||||
}
|
||||
|
||||
// Stop hole punch manager
|
||||
if s.holePunchManager != nil {
|
||||
s.holePunchManager.Stop()
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Close WireGuard device first - this will automatically close the TUN device
|
||||
// Close WireGuard device first - this will call sharedBind.Close() which releases WireGuard's reference
|
||||
if s.device != nil {
|
||||
s.device.Close()
|
||||
s.device = nil
|
||||
@@ -360,28 +371,22 @@ func (s *WireGuardService) Close(rm bool) {
|
||||
if s.tun != nil {
|
||||
s.tun = nil // Don't call tun.Close() here since device.Close() already closed it
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) StartHolepunch(serverPubKey string, endpoint string) {
|
||||
// if the device is already created dont start a new holepunch
|
||||
if s.device != nil {
|
||||
return
|
||||
// Release the hole punch reference to the shared bind
|
||||
if s.sharedBind != nil {
|
||||
// Release hole punch reference (WireGuard already released its reference via device.Close())
|
||||
logger.Debug("Releasing shared bind (refcount before release: %d)", s.sharedBind.GetRefCount())
|
||||
s.sharedBind.Release()
|
||||
s.sharedBind = nil
|
||||
logger.Info("Released shared UDP bind")
|
||||
}
|
||||
|
||||
s.serverPubKey = serverPubKey
|
||||
s.holePunchEndpoint = endpoint
|
||||
|
||||
logger.Debug("Starting UDP hole punch to %s", s.holePunchEndpoint)
|
||||
|
||||
// Create a new stop channel for this holepunch session
|
||||
s.stopHolepunch = make(chan struct{})
|
||||
|
||||
// start the UDP holepunch
|
||||
go s.keepSendingUDPHolePunch(s.holePunchEndpoint)
|
||||
}
|
||||
|
||||
func (s *WireGuardService) SetToken(token string) {
|
||||
s.token = token
|
||||
if s.holePunchManager != nil {
|
||||
s.holePunchManager.SetToken(token)
|
||||
}
|
||||
}
|
||||
|
||||
// GetNetstackNet returns the netstack network interface for use by other components
|
||||
@@ -412,6 +417,19 @@ func (s *WireGuardService) SetOnNetstackClose(callback func()) {
|
||||
s.onNetstackClose = callback
|
||||
}
|
||||
|
||||
// StartHolepunch starts hole punching to a specific endpoint
|
||||
func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string) {
|
||||
if s.holePunchManager == nil {
|
||||
logger.Warn("Hole punch manager not initialized")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Starting hole punch to %s with public key: %s", endpoint, publicKey)
|
||||
if err := s.holePunchManager.StartSingleEndpoint(endpoint, publicKey); err != nil {
|
||||
logger.Warn("Failed to start hole punch: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) LoadRemoteConfig() error {
|
||||
if s.stopGetConfig != nil {
|
||||
s.stopGetConfig()
|
||||
@@ -485,10 +503,9 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
// Parse the IP address and CIDR mask
|
||||
tunnelIP := netip.MustParseAddr(parts[0])
|
||||
|
||||
// stop the holepunch its a channel
|
||||
if s.stopHolepunch != nil {
|
||||
close(s.stopHolepunch)
|
||||
s.stopHolepunch = nil
|
||||
// Stop any ongoing hole punch operations
|
||||
if s.holePunchManager != nil {
|
||||
s.holePunchManager.Stop()
|
||||
}
|
||||
|
||||
// Parse the IP address from the config
|
||||
@@ -512,8 +529,8 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
// s.proxyManager.SetTNet(s.tnet)
|
||||
s.TunnelIP = tunnelIP.String()
|
||||
|
||||
// Create WireGuard device
|
||||
s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger(
|
||||
// Create WireGuard device using the shared bind
|
||||
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
|
||||
device.LogLevelSilent, // Use silent logging by default - could be made configurable
|
||||
"wireguard: ",
|
||||
))
|
||||
@@ -946,171 +963,6 @@ func (s *WireGuardService) reportPeerBandwidth() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error {
|
||||
|
||||
if s.serverPubKey == "" || s.token == "" {
|
||||
logger.Debug("Server public key or token not set, skipping UDP hole punch")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse server address
|
||||
serverSplit := strings.Split(serverAddr, ":")
|
||||
if len(serverSplit) < 2 {
|
||||
return fmt.Errorf("invalid server address format, expected hostname:port")
|
||||
}
|
||||
|
||||
serverHostname := serverSplit[0]
|
||||
serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse server port: %v", err)
|
||||
}
|
||||
|
||||
// Resolve server hostname to IP
|
||||
serverIPAddr := network.HostToAddr(serverHostname)
|
||||
if serverIPAddr == nil {
|
||||
return fmt.Errorf("failed to resolve server hostname")
|
||||
}
|
||||
|
||||
// Create local UDP address using the same port as WireGuard
|
||||
localAddr := &net.UDPAddr{
|
||||
IP: net.IPv4zero,
|
||||
Port: int(s.Port),
|
||||
}
|
||||
|
||||
// Create remote server address
|
||||
remoteAddr := &net.UDPAddr{
|
||||
IP: serverIPAddr.IP,
|
||||
Port: int(serverPort),
|
||||
}
|
||||
|
||||
// Create UDP connection bound to the same port as WireGuard
|
||||
conn, err := net.DialUDP("udp", localAddr, remoteAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create netstack UDP connection: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Create JSON payload
|
||||
payload := struct {
|
||||
NewtID string `json:"newtId"`
|
||||
Token string `json:"token"`
|
||||
}{
|
||||
NewtID: s.newtId,
|
||||
Token: s.token,
|
||||
}
|
||||
|
||||
// Convert payload to JSON
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload using the server's WireGuard public key
|
||||
encryptedPayload, err := s.encryptPayload(payloadBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt payload: %v", err)
|
||||
}
|
||||
|
||||
// Convert encrypted payload to JSON
|
||||
jsonData, err := json.Marshal(encryptedPayload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal encrypted payload: %v", err)
|
||||
}
|
||||
|
||||
// Send the encrypted packet using the netstack UDP connection
|
||||
_, err = conn.Write(jsonData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send UDP packet: %v", err)
|
||||
}
|
||||
|
||||
logger.Debug("Sent UDP hole punch to %s via netstack", remoteAddr.String())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) {
|
||||
// Generate an ephemeral keypair for this message
|
||||
ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err)
|
||||
}
|
||||
ephemeralPublicKey := ephemeralPrivateKey.PublicKey()
|
||||
|
||||
// Parse the server's public key
|
||||
serverPubKey, err := wgtypes.ParseKey(s.serverPubKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse server public key: %v", err)
|
||||
}
|
||||
|
||||
// Use X25519 for key exchange (replacing deprecated ScalarMult)
|
||||
var ephPrivKeyFixed [32]byte
|
||||
copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:])
|
||||
|
||||
// Perform X25519 key exchange
|
||||
sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err)
|
||||
}
|
||||
|
||||
// Create an AEAD cipher using the shared secret
|
||||
aead, err := chacha20poly1305.New(sharedSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AEAD cipher: %v", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %v", err)
|
||||
}
|
||||
|
||||
// Encrypt the payload
|
||||
ciphertext := aead.Seal(nil, nonce, payload, nil)
|
||||
|
||||
// Prepare the final encrypted message
|
||||
encryptedMsg := struct {
|
||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
}{
|
||||
EphemeralPublicKey: ephemeralPublicKey.String(),
|
||||
Nonce: nonce,
|
||||
Ciphertext: ciphertext,
|
||||
}
|
||||
|
||||
return encryptedMsg, nil
|
||||
}
|
||||
|
||||
func (s *WireGuardService) keepSendingUDPHolePunch(host string) {
|
||||
logger.Info("Starting UDP hole punch routine to %s:21820", host)
|
||||
|
||||
// send initial hole punch
|
||||
if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
|
||||
logger.Debug("Failed to send initial UDP hole punch: %v", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(3 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.NewTimer(15 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.stopHolepunch:
|
||||
logger.Info("Stopping UDP holepunch")
|
||||
return
|
||||
case <-timeout.C:
|
||||
logger.Info("UDP holepunch routine timed out after 15 seconds")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.sendUDPHolePunch(host + ":21820"); err != nil {
|
||||
logger.Debug("Failed to send UDP hole punch: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WireGuardService) updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
|
||||
var replace = false
|
||||
for _, t := range targetData.Targets {
|
||||
@@ -1242,8 +1094,8 @@ func (s *WireGuardService) ReplaceNetstack() error {
|
||||
s.tun = newTun
|
||||
s.tnet = newTnet
|
||||
|
||||
// Create new WireGuard device with same port
|
||||
s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger(
|
||||
// Create new WireGuard device with same shared bind
|
||||
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
|
||||
device.LogLevelSilent,
|
||||
"wireguard: ",
|
||||
))
|
||||
|
||||
Reference in New Issue
Block a user