diff --git a/bind/shared_bind.go b/bind/shared_bind.go new file mode 100644 index 0000000..bff66bf --- /dev/null +++ b/bind/shared_bind.go @@ -0,0 +1,378 @@ +//go:build !js + +package bind + +import ( + "fmt" + "net" + "net/netip" + "runtime" + "sync" + "sync/atomic" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + wgConn "golang.zx2c4.com/wireguard/conn" +) + +// Endpoint represents a network endpoint for the SharedBind +type Endpoint struct { + AddrPort netip.AddrPort +} + +// ClearSrc implements the wgConn.Endpoint interface +func (e *Endpoint) ClearSrc() {} + +// DstIP implements the wgConn.Endpoint interface +func (e *Endpoint) DstIP() netip.Addr { + return e.AddrPort.Addr() +} + +// SrcIP implements the wgConn.Endpoint interface +func (e *Endpoint) SrcIP() netip.Addr { + return netip.Addr{} +} + +// DstToBytes implements the wgConn.Endpoint interface +func (e *Endpoint) DstToBytes() []byte { + b, _ := e.AddrPort.MarshalBinary() + return b +} + +// DstToString implements the wgConn.Endpoint interface +func (e *Endpoint) DstToString() string { + return e.AddrPort.String() +} + +// SrcToString implements the wgConn.Endpoint interface +func (e *Endpoint) SrcToString() string { + return "" +} + +// SharedBind is a thread-safe UDP bind that can be shared between WireGuard +// and hole punch senders. It wraps a single UDP connection and implements +// reference counting to prevent premature closure. +type SharedBind struct { + mu sync.RWMutex + + // The underlying UDP connection + udpConn *net.UDPConn + + // IPv4 and IPv6 packet connections for advanced features + ipv4PC *ipv4.PacketConn + ipv6PC *ipv6.PacketConn + + // Reference counting to prevent closing while in use + refCount atomic.Int32 + closed atomic.Bool + + // Channels for receiving data + recvFuncs []wgConn.ReceiveFunc + + // Port binding information + port uint16 +} + +// New creates a new SharedBind from an existing UDP connection. +// The SharedBind takes ownership of the connection and will close it +// when all references are released. +func New(udpConn *net.UDPConn) (*SharedBind, error) { + if udpConn == nil { + return nil, fmt.Errorf("udpConn cannot be nil") + } + + bind := &SharedBind{ + udpConn: udpConn, + } + + // Initialize reference count to 1 (the creator holds the first reference) + bind.refCount.Store(1) + + // Get the local port + if addr, ok := udpConn.LocalAddr().(*net.UDPAddr); ok { + bind.port = uint16(addr.Port) + } + + return bind, nil +} + +// AddRef increments the reference count. Call this when sharing +// the bind with another component. +func (b *SharedBind) AddRef() { + newCount := b.refCount.Add(1) + // Optional: Add logging for debugging + _ = newCount // Placeholder for potential logging +} + +// Release decrements the reference count. When it reaches zero, +// the underlying UDP connection is closed. +func (b *SharedBind) Release() error { + newCount := b.refCount.Add(-1) + // Optional: Add logging for debugging + _ = newCount // Placeholder for potential logging + + if newCount < 0 { + // This should never happen with proper usage + b.refCount.Store(0) + return fmt.Errorf("SharedBind reference count went negative") + } + + if newCount == 0 { + return b.closeConnection() + } + + return nil +} + +// closeConnection actually closes the UDP connection +func (b *SharedBind) closeConnection() error { + if !b.closed.CompareAndSwap(false, true) { + // Already closed + return nil + } + + b.mu.Lock() + defer b.mu.Unlock() + + var err error + if b.udpConn != nil { + err = b.udpConn.Close() + b.udpConn = nil + } + + b.ipv4PC = nil + b.ipv6PC = nil + + return err +} + +// GetUDPConn returns the underlying UDP connection. +// The caller must not close this connection directly. +func (b *SharedBind) GetUDPConn() *net.UDPConn { + b.mu.RLock() + defer b.mu.RUnlock() + return b.udpConn +} + +// GetRefCount returns the current reference count (for debugging) +func (b *SharedBind) GetRefCount() int32 { + return b.refCount.Load() +} + +// IsClosed returns whether the bind is closed +func (b *SharedBind) IsClosed() bool { + return b.closed.Load() +} + +// WriteToUDP writes data to a specific UDP address. +// This is thread-safe and can be used by hole punch senders. +func (b *SharedBind) WriteToUDP(data []byte, addr *net.UDPAddr) (int, error) { + if b.closed.Load() { + return 0, net.ErrClosed + } + + b.mu.RLock() + conn := b.udpConn + b.mu.RUnlock() + + if conn == nil { + return 0, net.ErrClosed + } + + return conn.WriteToUDP(data, addr) +} + +// Close implements the WireGuard Bind interface. +// It decrements the reference count and closes the connection if no references remain. +func (b *SharedBind) Close() error { + return b.Release() +} + +// Open implements the WireGuard Bind interface. +// Since the connection is already open, this just sets up the receive functions. +func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { + if b.closed.Load() { + return nil, 0, net.ErrClosed + } + + b.mu.Lock() + defer b.mu.Unlock() + + if b.udpConn == nil { + return nil, 0, net.ErrClosed + } + + // Set up IPv4 and IPv6 packet connections for advanced features + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + b.ipv4PC = ipv4.NewPacketConn(b.udpConn) + b.ipv6PC = ipv6.NewPacketConn(b.udpConn) + } + + // Create receive functions + recvFuncs := make([]wgConn.ReceiveFunc, 0, 2) + + // Add IPv4 receive function + if b.ipv4PC != nil || runtime.GOOS != "linux" { + recvFuncs = append(recvFuncs, b.makeReceiveIPv4()) + } + + // Add IPv6 receive function if needed + // For now, we focus on IPv4 for hole punching use case + + b.recvFuncs = recvFuncs + return recvFuncs, b.port, nil +} + +// makeReceiveIPv4 creates a receive function for IPv4 packets +func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { + if b.closed.Load() { + return 0, net.ErrClosed + } + + b.mu.RLock() + conn := b.udpConn + pc := b.ipv4PC + b.mu.RUnlock() + + if conn == nil { + return 0, net.ErrClosed + } + + // Use batch reading on Linux for performance + if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { + return b.receiveIPv4Batch(pc, bufs, sizes, eps) + } + + // Fallback to simple read for other platforms + return b.receiveIPv4Simple(conn, bufs, sizes, eps) + } +} + +// receiveIPv4Batch uses batch reading for better performance on Linux +func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) { + // Create messages for batch reading + msgs := make([]ipv4.Message, len(bufs)) + for i := range bufs { + msgs[i].Buffers = [][]byte{bufs[i]} + msgs[i].OOB = make([]byte, 0) // No OOB data needed for basic use + } + + numMsgs, err := pc.ReadBatch(msgs, 0) + if err != nil { + return 0, err + } + + for i := 0; i < numMsgs; i++ { + sizes[i] = msgs[i].N + if sizes[i] == 0 { + continue + } + + if msgs[i].Addr != nil { + if udpAddr, ok := msgs[i].Addr.(*net.UDPAddr); ok { + addrPort := udpAddr.AddrPort() + eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort} + } + } + } + + return numMsgs, nil +} + +// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms +func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) { + n, addr, err := conn.ReadFromUDP(bufs[0]) + if err != nil { + return 0, err + } + + sizes[0] = n + if addr != nil { + addrPort := addr.AddrPort() + eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort} + } + + return 1, nil +} + +// Send implements the WireGuard Bind interface. +// It sends packets to the specified endpoint. +func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { + if b.closed.Load() { + return net.ErrClosed + } + + b.mu.RLock() + conn := b.udpConn + b.mu.RUnlock() + + if conn == nil { + return net.ErrClosed + } + + // Extract the destination address from the endpoint + var destAddr *net.UDPAddr + + // Try to cast to StdNetEndpoint first + if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok { + destAddr = net.UDPAddrFromAddrPort(stdEp.AddrPort) + } else { + // Fallback: construct from DstIP and DstToBytes + dstBytes := ep.DstToBytes() + if len(dstBytes) >= 6 { // Minimum for IPv4 (4 bytes) + port (2 bytes) + var addr netip.Addr + var port uint16 + + if len(dstBytes) >= 18 { // IPv6 (16 bytes) + port (2 bytes) + addr, _ = netip.AddrFromSlice(dstBytes[:16]) + port = uint16(dstBytes[16]) | uint16(dstBytes[17])<<8 + } else { // IPv4 + addr, _ = netip.AddrFromSlice(dstBytes[:4]) + port = uint16(dstBytes[4]) | uint16(dstBytes[5])<<8 + } + + if addr.IsValid() { + destAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port)) + } + } + } + + if destAddr == nil { + return fmt.Errorf("could not extract destination address from endpoint") + } + + // Send all buffers to the destination + for _, buf := range bufs { + _, err := conn.WriteToUDP(buf, destAddr) + if err != nil { + return err + } + } + + return nil +} + +// SetMark implements the WireGuard Bind interface. +// It's a no-op for this implementation. +func (b *SharedBind) SetMark(mark uint32) error { + // Not implemented for this use case + return nil +} + +// BatchSize returns the preferred batch size for sending packets. +func (b *SharedBind) BatchSize() int { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + return wgConn.IdealBatchSize + } + return 1 +} + +// ParseEndpoint creates a new endpoint from a string address. +func (b *SharedBind) ParseEndpoint(s string) (wgConn.Endpoint, error) { + addrPort, err := netip.ParseAddrPort(s) + if err != nil { + return nil, err + } + return &wgConn.StdNetEndpoint{AddrPort: addrPort}, nil +} diff --git a/bind/shared_bind_test.go b/bind/shared_bind_test.go new file mode 100644 index 0000000..6e1ec66 --- /dev/null +++ b/bind/shared_bind_test.go @@ -0,0 +1,424 @@ +//go:build !js + +package bind + +import ( + "net" + "net/netip" + "sync" + "testing" + "time" + + wgConn "golang.zx2c4.com/wireguard/conn" +) + +// TestSharedBindCreation tests basic creation and initialization +func TestSharedBindCreation(t *testing.T) { + // Create a UDP connection + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create UDP connection: %v", err) + } + defer udpConn.Close() + + // Create SharedBind + bind, err := New(udpConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + + if bind == nil { + t.Fatal("SharedBind is nil") + } + + // Verify initial reference count + if bind.refCount.Load() != 1 { + t.Errorf("Expected initial refCount to be 1, got %d", bind.refCount.Load()) + } + + // Clean up + if err := bind.Close(); err != nil { + t.Errorf("Failed to close SharedBind: %v", err) + } +} + +// TestSharedBindReferenceCount tests reference counting +func TestSharedBindReferenceCount(t *testing.T) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create UDP connection: %v", err) + } + + bind, err := New(udpConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + + // Add references + bind.AddRef() + if bind.refCount.Load() != 2 { + t.Errorf("Expected refCount to be 2, got %d", bind.refCount.Load()) + } + + bind.AddRef() + if bind.refCount.Load() != 3 { + t.Errorf("Expected refCount to be 3, got %d", bind.refCount.Load()) + } + + // Release references + bind.Release() + if bind.refCount.Load() != 2 { + t.Errorf("Expected refCount to be 2 after release, got %d", bind.refCount.Load()) + } + + bind.Release() + bind.Release() // This should close the connection + + if !bind.closed.Load() { + t.Error("Expected bind to be closed after all references released") + } +} + +// TestSharedBindWriteToUDP tests the WriteToUDP functionality +func TestSharedBindWriteToUDP(t *testing.T) { + // Create sender + senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create sender UDP connection: %v", err) + } + + senderBind, err := New(senderConn) + if err != nil { + t.Fatalf("Failed to create sender SharedBind: %v", err) + } + defer senderBind.Close() + + // Create receiver + receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create receiver UDP connection: %v", err) + } + defer receiverConn.Close() + + receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) + + // Send data + testData := []byte("Hello, SharedBind!") + n, err := senderBind.WriteToUDP(testData, receiverAddr) + if err != nil { + t.Fatalf("WriteToUDP failed: %v", err) + } + + if n != len(testData) { + t.Errorf("Expected to send %d bytes, sent %d", len(testData), n) + } + + // Receive data + buf := make([]byte, 1024) + receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, _, err = receiverConn.ReadFromUDP(buf) + if err != nil { + t.Fatalf("Failed to receive data: %v", err) + } + + if string(buf[:n]) != string(testData) { + t.Errorf("Expected to receive %q, got %q", testData, buf[:n]) + } +} + +// TestSharedBindConcurrentWrites tests thread-safety +func TestSharedBindConcurrentWrites(t *testing.T) { + // Create sender + senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create sender UDP connection: %v", err) + } + + senderBind, err := New(senderConn) + if err != nil { + t.Fatalf("Failed to create sender SharedBind: %v", err) + } + defer senderBind.Close() + + // Create receiver + receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create receiver UDP connection: %v", err) + } + defer receiverConn.Close() + + receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) + + // Launch concurrent writes + numGoroutines := 100 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + data := []byte{byte(id)} + _, err := senderBind.WriteToUDP(data, receiverAddr) + if err != nil { + t.Errorf("WriteToUDP failed in goroutine %d: %v", id, err) + } + }(i) + } + + wg.Wait() +} + +// TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation +func TestSharedBindWireGuardInterface(t *testing.T) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create UDP connection: %v", err) + } + + bind, err := New(udpConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + defer bind.Close() + + // Test Open + recvFuncs, port, err := bind.Open(0) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + if len(recvFuncs) == 0 { + t.Error("Expected at least one receive function") + } + + if port == 0 { + t.Error("Expected non-zero port") + } + + // Test SetMark (should be a no-op) + if err := bind.SetMark(0); err != nil { + t.Errorf("SetMark failed: %v", err) + } + + // Test BatchSize + batchSize := bind.BatchSize() + if batchSize <= 0 { + t.Error("Expected positive batch size") + } +} + +// TestSharedBindSend tests the Send method with WireGuard endpoints +func TestSharedBindSend(t *testing.T) { + // Create sender + senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create sender UDP connection: %v", err) + } + + senderBind, err := New(senderConn) + if err != nil { + t.Fatalf("Failed to create sender SharedBind: %v", err) + } + defer senderBind.Close() + + // Create receiver + receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create receiver UDP connection: %v", err) + } + defer receiverConn.Close() + + receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) + + // Create an endpoint + addrPort := receiverAddr.AddrPort() + endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort} + + // Send data + testData := []byte("WireGuard packet") + bufs := [][]byte{testData} + err = senderBind.Send(bufs, endpoint) + if err != nil { + t.Fatalf("Send failed: %v", err) + } + + // Receive data + buf := make([]byte, 1024) + receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, _, err := receiverConn.ReadFromUDP(buf) + if err != nil { + t.Fatalf("Failed to receive data: %v", err) + } + + if string(buf[:n]) != string(testData) { + t.Errorf("Expected to receive %q, got %q", testData, buf[:n]) + } +} + +// TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind +func TestSharedBindMultipleUsers(t *testing.T) { + // Create shared bind + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create UDP connection: %v", err) + } + + sharedBind, err := New(udpConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + + // Add reference for hole punch sender + sharedBind.AddRef() + + // Create receiver + receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create receiver UDP connection: %v", err) + } + defer receiverConn.Close() + + receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) + + var wg sync.WaitGroup + + // Simulate WireGuard using the bind + wg.Add(1) + go func() { + defer wg.Done() + addrPort := receiverAddr.AddrPort() + endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort} + + for i := 0; i < 10; i++ { + data := []byte("WireGuard packet") + bufs := [][]byte{data} + if err := sharedBind.Send(bufs, endpoint); err != nil { + t.Errorf("WireGuard Send failed: %v", err) + } + time.Sleep(10 * time.Millisecond) + } + }() + + // Simulate hole punch sender using the bind + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 10; i++ { + data := []byte("Hole punch packet") + if _, err := sharedBind.WriteToUDP(data, receiverAddr); err != nil { + t.Errorf("Hole punch WriteToUDP failed: %v", err) + } + time.Sleep(10 * time.Millisecond) + } + }() + + wg.Wait() + + // Release the hole punch reference + sharedBind.Release() + + // Close WireGuard's reference (should close the connection) + sharedBind.Close() + + if !sharedBind.closed.Load() { + t.Error("Expected bind to be closed after all users released it") + } +} + +// TestEndpoint tests the Endpoint implementation +func TestEndpoint(t *testing.T) { + addr := netip.MustParseAddr("192.168.1.1") + addrPort := netip.AddrPortFrom(addr, 51820) + + ep := &Endpoint{AddrPort: addrPort} + + // Test DstIP + if ep.DstIP() != addr { + t.Errorf("Expected DstIP to be %v, got %v", addr, ep.DstIP()) + } + + // Test DstToString + expected := "192.168.1.1:51820" + if ep.DstToString() != expected { + t.Errorf("Expected DstToString to be %q, got %q", expected, ep.DstToString()) + } + + // Test DstToBytes + bytes := ep.DstToBytes() + if len(bytes) == 0 { + t.Error("Expected DstToBytes to return non-empty slice") + } + + // Test SrcIP (should be zero) + if ep.SrcIP().IsValid() { + t.Error("Expected SrcIP to be invalid") + } + + // Test ClearSrc (should not panic) + ep.ClearSrc() +} + +// TestParseEndpoint tests the ParseEndpoint method +func TestParseEndpoint(t *testing.T) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create UDP connection: %v", err) + } + + bind, err := New(udpConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + defer bind.Close() + + tests := []struct { + name string + input string + wantErr bool + checkAddr func(*testing.T, wgConn.Endpoint) + }{ + { + name: "valid IPv4", + input: "192.168.1.1:51820", + wantErr: false, + checkAddr: func(t *testing.T, ep wgConn.Endpoint) { + if ep.DstToString() != "192.168.1.1:51820" { + t.Errorf("Expected 192.168.1.1:51820, got %s", ep.DstToString()) + } + }, + }, + { + name: "valid IPv6", + input: "[::1]:51820", + wantErr: false, + checkAddr: func(t *testing.T, ep wgConn.Endpoint) { + if ep.DstToString() != "[::1]:51820" { + t.Errorf("Expected [::1]:51820, got %s", ep.DstToString()) + } + }, + }, + { + name: "invalid - missing port", + input: "192.168.1.1", + wantErr: true, + }, + { + name: "invalid - bad format", + input: "not-an-address", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ep, err := bind.ParseEndpoint(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ParseEndpoint() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && tt.checkAddr != nil { + tt.checkAddr(t, ep) + } + }) + } +} diff --git a/util.go b/common.go similarity index 93% rename from util.go rename to common.go index dc48f19..454283a 100644 --- a/util.go +++ b/common.go @@ -7,7 +7,6 @@ import ( "encoding/hex" "encoding/json" "fmt" - "net" "os" "os/exec" "strings" @@ -398,57 +397,6 @@ func mapToWireGuardLogLevel(level logger.LogLevel) int { } } -func resolveDomain(domain string) (string, error) { - // Check if there's a port in the domain - host, port, err := net.SplitHostPort(domain) - if err != nil { - // No port found, use the domain as is - host = domain - port = "" - } - - // Remove any protocol prefix if present - if strings.HasPrefix(host, "http://") { - host = strings.TrimPrefix(host, "http://") - } else if strings.HasPrefix(host, "https://") { - host = strings.TrimPrefix(host, "https://") - } - - // if there are any trailing slashes, remove them - host = strings.TrimSuffix(host, "/") - - // Lookup IP addresses - ips, err := net.LookupIP(host) - if err != nil { - return "", fmt.Errorf("DNS lookup failed: %v", err) - } - - if len(ips) == 0 { - return "", fmt.Errorf("no IP addresses found for domain %s", host) - } - - // Get the first IPv4 address if available - var ipAddr string - for _, ip := range ips { - if ipv4 := ip.To4(); ipv4 != nil { - ipAddr = ipv4.String() - break - } - } - - // If no IPv4 found, use the first IP (might be IPv6) - if ipAddr == "" { - ipAddr = ips[0].String() - } - - // Add port back if it existed - if port != "" { - ipAddr = net.JoinHostPort(ipAddr, port) - } - - return ipAddr, nil -} - func parseTargetData(data interface{}) (TargetData, error) { var targetData TargetData jsonData, err := json.Marshal(data) diff --git a/go.mod b/go.mod index 5a930b6..32c1ae3 100644 --- a/go.mod +++ b/go.mod @@ -17,9 +17,9 @@ require ( go.opentelemetry.io/otel/metric v1.38.0 go.opentelemetry.io/otel/sdk v1.38.0 go.opentelemetry.io/otel/sdk/metric v1.38.0 - golang.org/x/crypto v0.43.0 - golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 - golang.org/x/net v0.46.0 + golang.org/x/crypto v0.44.0 + golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 + golang.org/x/net v0.47.0 golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 google.golang.org/grpc v1.76.0 @@ -69,12 +69,12 @@ require ( go.opentelemetry.io/otel/trace v1.38.0 // indirect go.opentelemetry.io/proto/otlp v1.7.1 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect - golang.org/x/mod v0.28.0 // indirect - golang.org/x/sync v0.17.0 // indirect - golang.org/x/sys v0.37.0 // indirect - golang.org/x/text v0.30.0 // indirect + golang.org/x/mod v0.30.0 // indirect + golang.org/x/sync v0.18.0 // indirect + golang.org/x/sys v0.38.0 // indirect + golang.org/x/text v0.31.0 // indirect golang.org/x/time v0.12.0 // indirect - golang.org/x/tools v0.37.0 // indirect + golang.org/x/tools v0.39.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect diff --git a/go.sum b/go.sum index 81cbe33..d322b92 100644 --- a/go.sum +++ b/go.sum @@ -107,32 +107,47 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= +golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU= +golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc= golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc= +golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= +golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.28.0 h1:gQBtGhjxykdjY9YhZpSlZIsbnaE2+PgjfLWUQTnoZ1U= golang.org/x/mod v0.28.0/go.mod h1:yfB/L0NOf/kmEbXjzCPOx1iK1fRutOydrCMsqRhEBxI= +golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= +golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.37.0 h1:DVSRzp7FwePZW356yEAChSdNcQo6Nsp+fex1SUW09lE= golang.org/x/tools v0.37.0/go.mod h1:MBN5QPQtLMHVdvsbtarmTNukZDdgwdwlO5qGacAzF0w= +golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= diff --git a/holepunch/holepunch.go b/holepunch/holepunch.go new file mode 100644 index 0000000..dfe9c74 --- /dev/null +++ b/holepunch/holepunch.go @@ -0,0 +1,347 @@ +package holepunch + +import ( + "encoding/json" + "fmt" + "net" + "sync" + "time" + + "github.com/fosrl/newt/bind" + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/util" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/curve25519" + "golang.org/x/exp/rand" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// ExitNode represents a WireGuard exit node for hole punching +type ExitNode struct { + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` +} + +// Manager handles UDP hole punching operations +type Manager struct { + mu sync.Mutex + running bool + stopChan chan struct{} + sharedBind *bind.SharedBind + newtID string + token string +} + +// NewManager creates a new hole punch manager +func NewManager(sharedBind *bind.SharedBind, newtID string) *Manager { + return &Manager{ + sharedBind: sharedBind, + newtID: newtID, + } +} + +// SetToken updates the authentication token used for hole punching +func (m *Manager) SetToken(token string) { + m.mu.Lock() + defer m.mu.Unlock() + m.token = token +} + +// IsRunning returns whether hole punching is currently active +func (m *Manager) IsRunning() bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.running +} + +// Stop stops any ongoing hole punch operations +func (m *Manager) Stop() { + m.mu.Lock() + defer m.mu.Unlock() + + if !m.running { + return + } + + if m.stopChan != nil { + close(m.stopChan) + m.stopChan = nil + } + + m.running = false + logger.Info("Hole punch manager stopped") +} + +// StartMultipleExitNodes starts hole punching to multiple exit nodes +func (m *Manager) StartMultipleExitNodes(exitNodes []ExitNode) error { + m.mu.Lock() + + if m.running { + m.mu.Unlock() + logger.Debug("UDP hole punch already running, skipping new request") + return fmt.Errorf("hole punch already running") + } + + if len(exitNodes) == 0 { + m.mu.Unlock() + logger.Warn("No exit nodes provided for hole punching") + return fmt.Errorf("no exit nodes provided") + } + + m.running = true + m.stopChan = make(chan struct{}) + m.mu.Unlock() + + logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes)) + + go m.runMultipleExitNodes(exitNodes) + + return nil +} + +// StartSingleEndpoint starts hole punching to a single endpoint (legacy mode) +func (m *Manager) StartSingleEndpoint(endpoint, serverPubKey string) error { + m.mu.Lock() + + if m.running { + m.mu.Unlock() + logger.Debug("UDP hole punch already running, skipping new request") + return fmt.Errorf("hole punch already running") + } + + m.running = true + m.stopChan = make(chan struct{}) + m.mu.Unlock() + + logger.Info("Starting UDP hole punch to %s with shared bind", endpoint) + + go m.runSingleEndpoint(endpoint, serverPubKey) + + return nil +} + +// runMultipleExitNodes performs hole punching to multiple exit nodes +func (m *Manager) runMultipleExitNodes(exitNodes []ExitNode) { + defer func() { + m.mu.Lock() + m.running = false + m.mu.Unlock() + logger.Info("UDP hole punch goroutine ended for all exit nodes") + }() + + // Resolve all endpoints upfront + type resolvedExitNode struct { + remoteAddr *net.UDPAddr + publicKey string + endpointName string + } + + var resolvedNodes []resolvedExitNode + for _, exitNode := range exitNodes { + host, err := util.ResolveDomain(exitNode.Endpoint) + if err != nil { + logger.Warn("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err) + continue + } + + serverAddr := net.JoinHostPort(host, "21820") + remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err) + continue + } + + resolvedNodes = append(resolvedNodes, resolvedExitNode{ + remoteAddr: remoteAddr, + publicKey: exitNode.PublicKey, + endpointName: exitNode.Endpoint, + }) + logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String()) + } + + if len(resolvedNodes) == 0 { + logger.Error("No exit nodes could be resolved") + return + } + + // Send initial hole punch to all exit nodes + for _, node := range resolvedNodes { + if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil { + logger.Warn("Failed to send initial hole punch to %s: %v", node.endpointName, err) + } + } + + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + + timeout := time.NewTimer(15 * time.Second) + defer timeout.Stop() + + for { + select { + case <-m.stopChan: + logger.Debug("Hole punch stopped by signal") + return + case <-timeout.C: + logger.Debug("Hole punch timeout reached") + return + case <-ticker.C: + // Send hole punch to all exit nodes + for _, node := range resolvedNodes { + if err := m.sendHolePunch(node.remoteAddr, node.publicKey); err != nil { + logger.Debug("Failed to send hole punch to %s: %v", node.endpointName, err) + } + } + } + } +} + +// runSingleEndpoint performs hole punching to a single endpoint +func (m *Manager) runSingleEndpoint(endpoint, serverPubKey string) { + defer func() { + m.mu.Lock() + m.running = false + m.mu.Unlock() + logger.Info("UDP hole punch goroutine ended for %s", endpoint) + }() + + host, err := util.ResolveDomain(endpoint) + if err != nil { + logger.Error("Failed to resolve domain %s: %v", endpoint, err) + return + } + + serverAddr := net.JoinHostPort(host, "21820") + + remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err) + return + } + + // Execute once immediately before starting the loop + if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil { + logger.Warn("Failed to send initial hole punch: %v", err) + } + + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + + timeout := time.NewTimer(15 * time.Second) + defer timeout.Stop() + + for { + select { + case <-m.stopChan: + logger.Debug("Hole punch stopped by signal") + return + case <-timeout.C: + logger.Debug("Hole punch timeout reached") + return + case <-ticker.C: + if err := m.sendHolePunch(remoteAddr, serverPubKey); err != nil { + logger.Debug("Failed to send hole punch: %v", err) + } + } + } +} + +// sendHolePunch sends an encrypted hole punch packet using the shared bind +func (m *Manager) sendHolePunch(remoteAddr *net.UDPAddr, serverPubKey string) error { + m.mu.Lock() + token := m.token + newtID := m.newtID + m.mu.Unlock() + + if serverPubKey == "" || token == "" { + return fmt.Errorf("server public key or OLM token is empty") + } + + payload := struct { + NewtID string `json:"newtId"` + Token string `json:"token"` + }{ + NewtID: newtID, + Token: token, + } + + // Convert payload to JSON + payloadBytes, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %w", err) + } + + // Encrypt the payload using the server's WireGuard public key + encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey) + if err != nil { + return fmt.Errorf("failed to encrypt payload: %w", err) + } + + jsonData, err := json.Marshal(encryptedPayload) + if err != nil { + return fmt.Errorf("failed to marshal encrypted payload: %w", err) + } + + _, err = m.sharedBind.WriteToUDP(jsonData, remoteAddr) + if err != nil { + return fmt.Errorf("failed to write to UDP: %w", err) + } + + logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData)) + + return nil +} + +// encryptPayload encrypts the payload using ChaCha20-Poly1305 AEAD with X25519 key exchange +func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) { + // Generate an ephemeral keypair for this message + ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err) + } + ephemeralPublicKey := ephemeralPrivateKey.PublicKey() + + // Parse the server's public key + serverPubKey, err := wgtypes.ParseKey(serverPublicKey) + if err != nil { + return nil, fmt.Errorf("failed to parse server public key: %v", err) + } + + // Use X25519 for key exchange + var ephPrivKeyFixed [32]byte + copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:]) + + // Perform X25519 key exchange + sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:]) + if err != nil { + return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err) + } + + // Create an AEAD cipher using the shared secret + aead, err := chacha20poly1305.New(sharedSecret) + if err != nil { + return nil, fmt.Errorf("failed to create AEAD cipher: %v", err) + } + + // Generate a random nonce + nonce := make([]byte, aead.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, fmt.Errorf("failed to generate nonce: %v", err) + } + + // Encrypt the payload + ciphertext := aead.Seal(nil, nonce, payload, nil) + + // Prepare the final encrypted message + encryptedMsg := struct { + EphemeralPublicKey string `json:"ephemeralPublicKey"` + Nonce []byte `json:"nonce"` + Ciphertext []byte `json:"ciphertext"` + }{ + EphemeralPublicKey: ephemeralPublicKey.String(), + Nonce: nonce, + Ciphertext: ciphertext, + } + + return encryptedMsg, nil +} diff --git a/main.go b/main.go index 57ac17c..0c625bb 100644 --- a/main.go +++ b/main.go @@ -22,6 +22,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/proxy" "github.com/fosrl/newt/updates" + "github.com/fosrl/newt/util" "github.com/fosrl/newt/websocket" "github.com/fosrl/newt/internal/state" @@ -663,7 +664,7 @@ func main() { logger.Info("Connecting to endpoint: %s", host) - endpoint, err := resolveDomain(wgData.Endpoint) + endpoint, err := util.ResolveDomain(wgData.Endpoint) if err != nil { logger.Error("Failed to resolve endpoint: %v", err) regResult = "failure" diff --git a/util/util.go b/util/util.go new file mode 100644 index 0000000..79fbde3 --- /dev/null +++ b/util/util.go @@ -0,0 +1,58 @@ +package util + +import ( + "fmt" + "net" + "strings" +) + +func ResolveDomain(domain string) (string, error) { + // Check if there's a port in the domain + host, port, err := net.SplitHostPort(domain) + if err != nil { + // No port found, use the domain as is + host = domain + port = "" + } + + // Remove any protocol prefix if present + if strings.HasPrefix(host, "http://") { + host = strings.TrimPrefix(host, "http://") + } else if strings.HasPrefix(host, "https://") { + host = strings.TrimPrefix(host, "https://") + } + + // if there are any trailing slashes, remove them + host = strings.TrimSuffix(host, "/") + + // Lookup IP addresses + ips, err := net.LookupIP(host) + if err != nil { + return "", fmt.Errorf("DNS lookup failed: %v", err) + } + + if len(ips) == 0 { + return "", fmt.Errorf("no IP addresses found for domain %s", host) + } + + // Get the first IPv4 address if available + var ipAddr string + for _, ip := range ips { + if ipv4 := ip.To4(); ipv4 != nil { + ipAddr = ipv4.String() + break + } + } + + // If no IPv4 found, use the first IP (might be IPv6) + if ipAddr == "" { + ipAddr = ips[0].String() + } + + // Add port back if it existed + if port != "" { + ipAddr = net.JoinHostPort(ipAddr, port) + } + + return ipAddr, nil +} diff --git a/wgnetstack/wgnetstack.go b/wgnetstack/wgnetstack.go index 63dcd1b..a376790 100644 --- a/wgnetstack/wgnetstack.go +++ b/wgnetstack/wgnetstack.go @@ -2,7 +2,6 @@ package wgnetstack import ( "context" - "crypto/rand" "encoding/base64" "encoding/hex" "encoding/json" @@ -16,14 +15,12 @@ import ( "sync" "time" + "github.com/fosrl/newt/bind" + "github.com/fosrl/newt/holepunch" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/netstack2" - "github.com/fosrl/newt/network" "github.com/fosrl/newt/proxy" "github.com/fosrl/newt/websocket" - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/curve25519" - "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun/netstack" @@ -66,22 +63,20 @@ type PeerReading struct { } type WireGuardService struct { - interfaceName string - mtu int - client *websocket.Client - config WgConfig - key wgtypes.Key - keyFilePath string - newtId string - lastReadings map[string]PeerReading - mu sync.Mutex - Port uint16 - stopHolepunch chan struct{} - host string - serverPubKey string - holePunchEndpoint string - token string - stopGetConfig func() + interfaceName string + mtu int + client *websocket.Client + config WgConfig + key wgtypes.Key + keyFilePath string + newtId string + lastReadings map[string]PeerReading + mu sync.Mutex + Port uint16 + host string + serverPubKey string + token string + stopGetConfig func() // Netstack fields tun tun.Device tnet *netstack2.Net @@ -95,6 +90,9 @@ type WireGuardService struct { // Proxy manager for tunnel proxyManager *proxy.ProxyManager TunnelIP string + // Shared bind and holepunch manager + sharedBind *bind.SharedBind + holePunchManager *holepunch.Manager } // GetProxyManager returns the proxy manager for this WireGuardService @@ -118,24 +116,6 @@ func (s *WireGuardService) RemoveProxyTarget(proto, listenIP string, port int) e return s.proxyManager.RemoveTarget(proto, listenIP, port) } -// Add this type definition -type fixedPortBind struct { - port uint16 - conn.Bind -} - -func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { - // Ignore the requested port and use our fixed port - return b.Bind.Open(b.port) -} - -func NewFixedPortBind(port uint16) conn.Bind { - return &fixedPortBind{ - port: port, - Bind: conn.NewDefaultBind(), - } -} - // find an available UDP port in the range [minPort, maxPort] and also the next port for the wgtester func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { if maxPort < minPort { @@ -215,6 +195,28 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str return nil, fmt.Errorf("error finding available port: %v", err) } + // Create shared UDP socket for both holepunch and WireGuard + localAddr := &net.UDPAddr{ + Port: int(port), + IP: net.IPv4zero, + } + + udpConn, err := net.ListenUDP("udp", localAddr) + if err != nil { + return nil, fmt.Errorf("failed to create UDP socket: %v", err) + } + + sharedBind, err := bind.New(udpConn) + if err != nil { + udpConn.Close() + return nil, fmt.Errorf("failed to create shared bind: %v", err) + } + + // Add a reference for the hole punch manager (creator already has one reference for WireGuard) + sharedBind.AddRef() + + logger.Info("Created shared UDP socket on port %d (refcount: %d)", port, sharedBind.GetRefCount()) + // Parse DNS addresses dnsAddrs := []netip.Addr{netip.MustParseAddr(dns)} @@ -227,12 +229,16 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str newtId: newtId, host: host, lastReadings: make(map[string]PeerReading), - stopHolepunch: make(chan struct{}), Port: port, dns: dnsAddrs, proxyManager: proxy.NewProxyManagerWithoutTNet(), + sharedBind: sharedBind, } + // Create the holepunch manager with ResolveDomain function + // We'll need to pass a domain resolver function + service.holePunchManager = holepunch.NewManager(sharedBind, newtId) + // Register websocket handlers wsClient.RegisterHandler("newt/wg/receive-config", service.handleConfig) wsClient.RegisterHandler("newt/wg/peer/add", service.handleAddPeer) @@ -344,10 +350,15 @@ func (s *WireGuardService) Close(rm bool) { s.stopGetConfig = nil } + // Stop hole punch manager + if s.holePunchManager != nil { + s.holePunchManager.Stop() + } + s.mu.Lock() defer s.mu.Unlock() - // Close WireGuard device first - this will automatically close the TUN device + // Close WireGuard device first - this will call sharedBind.Close() which releases WireGuard's reference if s.device != nil { s.device.Close() s.device = nil @@ -360,28 +371,22 @@ func (s *WireGuardService) Close(rm bool) { if s.tun != nil { s.tun = nil // Don't call tun.Close() here since device.Close() already closed it } -} -func (s *WireGuardService) StartHolepunch(serverPubKey string, endpoint string) { - // if the device is already created dont start a new holepunch - if s.device != nil { - return + // Release the hole punch reference to the shared bind + if s.sharedBind != nil { + // Release hole punch reference (WireGuard already released its reference via device.Close()) + logger.Debug("Releasing shared bind (refcount before release: %d)", s.sharedBind.GetRefCount()) + s.sharedBind.Release() + s.sharedBind = nil + logger.Info("Released shared UDP bind") } - - s.serverPubKey = serverPubKey - s.holePunchEndpoint = endpoint - - logger.Debug("Starting UDP hole punch to %s", s.holePunchEndpoint) - - // Create a new stop channel for this holepunch session - s.stopHolepunch = make(chan struct{}) - - // start the UDP holepunch - go s.keepSendingUDPHolePunch(s.holePunchEndpoint) } func (s *WireGuardService) SetToken(token string) { s.token = token + if s.holePunchManager != nil { + s.holePunchManager.SetToken(token) + } } // GetNetstackNet returns the netstack network interface for use by other components @@ -412,6 +417,19 @@ func (s *WireGuardService) SetOnNetstackClose(callback func()) { s.onNetstackClose = callback } +// StartHolepunch starts hole punching to a specific endpoint +func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string) { + if s.holePunchManager == nil { + logger.Warn("Hole punch manager not initialized") + return + } + + logger.Info("Starting hole punch to %s with public key: %s", endpoint, publicKey) + if err := s.holePunchManager.StartSingleEndpoint(endpoint, publicKey); err != nil { + logger.Warn("Failed to start hole punch: %v", err) + } +} + func (s *WireGuardService) LoadRemoteConfig() error { if s.stopGetConfig != nil { s.stopGetConfig() @@ -485,10 +503,9 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { // Parse the IP address and CIDR mask tunnelIP := netip.MustParseAddr(parts[0]) - // stop the holepunch its a channel - if s.stopHolepunch != nil { - close(s.stopHolepunch) - s.stopHolepunch = nil + // Stop any ongoing hole punch operations + if s.holePunchManager != nil { + s.holePunchManager.Stop() } // Parse the IP address from the config @@ -512,8 +529,8 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { // s.proxyManager.SetTNet(s.tnet) s.TunnelIP = tunnelIP.String() - // Create WireGuard device - s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger( + // Create WireGuard device using the shared bind + s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger( device.LogLevelSilent, // Use silent logging by default - could be made configurable "wireguard: ", )) @@ -946,171 +963,6 @@ func (s *WireGuardService) reportPeerBandwidth() error { return nil } -func (s *WireGuardService) sendUDPHolePunch(serverAddr string) error { - - if s.serverPubKey == "" || s.token == "" { - logger.Debug("Server public key or token not set, skipping UDP hole punch") - return nil - } - - // Parse server address - serverSplit := strings.Split(serverAddr, ":") - if len(serverSplit) < 2 { - return fmt.Errorf("invalid server address format, expected hostname:port") - } - - serverHostname := serverSplit[0] - serverPort, err := strconv.ParseUint(serverSplit[1], 10, 16) - if err != nil { - return fmt.Errorf("failed to parse server port: %v", err) - } - - // Resolve server hostname to IP - serverIPAddr := network.HostToAddr(serverHostname) - if serverIPAddr == nil { - return fmt.Errorf("failed to resolve server hostname") - } - - // Create local UDP address using the same port as WireGuard - localAddr := &net.UDPAddr{ - IP: net.IPv4zero, - Port: int(s.Port), - } - - // Create remote server address - remoteAddr := &net.UDPAddr{ - IP: serverIPAddr.IP, - Port: int(serverPort), - } - - // Create UDP connection bound to the same port as WireGuard - conn, err := net.DialUDP("udp", localAddr, remoteAddr) - if err != nil { - return fmt.Errorf("failed to create netstack UDP connection: %v", err) - } - defer conn.Close() - - // Create JSON payload - payload := struct { - NewtID string `json:"newtId"` - Token string `json:"token"` - }{ - NewtID: s.newtId, - Token: s.token, - } - - // Convert payload to JSON - payloadBytes, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal payload: %v", err) - } - - // Encrypt the payload using the server's WireGuard public key - encryptedPayload, err := s.encryptPayload(payloadBytes) - if err != nil { - return fmt.Errorf("failed to encrypt payload: %v", err) - } - - // Convert encrypted payload to JSON - jsonData, err := json.Marshal(encryptedPayload) - if err != nil { - return fmt.Errorf("failed to marshal encrypted payload: %v", err) - } - - // Send the encrypted packet using the netstack UDP connection - _, err = conn.Write(jsonData) - if err != nil { - return fmt.Errorf("failed to send UDP packet: %v", err) - } - - logger.Debug("Sent UDP hole punch to %s via netstack", remoteAddr.String()) - - return nil -} - -func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) { - // Generate an ephemeral keypair for this message - ephemeralPrivateKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - return nil, fmt.Errorf("failed to generate ephemeral private key: %v", err) - } - ephemeralPublicKey := ephemeralPrivateKey.PublicKey() - - // Parse the server's public key - serverPubKey, err := wgtypes.ParseKey(s.serverPubKey) - if err != nil { - return nil, fmt.Errorf("failed to parse server public key: %v", err) - } - - // Use X25519 for key exchange (replacing deprecated ScalarMult) - var ephPrivKeyFixed [32]byte - copy(ephPrivKeyFixed[:], ephemeralPrivateKey[:]) - - // Perform X25519 key exchange - sharedSecret, err := curve25519.X25519(ephPrivKeyFixed[:], serverPubKey[:]) - if err != nil { - return nil, fmt.Errorf("failed to perform X25519 key exchange: %v", err) - } - - // Create an AEAD cipher using the shared secret - aead, err := chacha20poly1305.New(sharedSecret) - if err != nil { - return nil, fmt.Errorf("failed to create AEAD cipher: %v", err) - } - - // Generate a random nonce - nonce := make([]byte, aead.NonceSize()) - if _, err := rand.Read(nonce); err != nil { - return nil, fmt.Errorf("failed to generate nonce: %v", err) - } - - // Encrypt the payload - ciphertext := aead.Seal(nil, nonce, payload, nil) - - // Prepare the final encrypted message - encryptedMsg := struct { - EphemeralPublicKey string `json:"ephemeralPublicKey"` - Nonce []byte `json:"nonce"` - Ciphertext []byte `json:"ciphertext"` - }{ - EphemeralPublicKey: ephemeralPublicKey.String(), - Nonce: nonce, - Ciphertext: ciphertext, - } - - return encryptedMsg, nil -} - -func (s *WireGuardService) keepSendingUDPHolePunch(host string) { - logger.Info("Starting UDP hole punch routine to %s:21820", host) - - // send initial hole punch - if err := s.sendUDPHolePunch(host + ":21820"); err != nil { - logger.Debug("Failed to send initial UDP hole punch: %v", err) - } - - ticker := time.NewTicker(3 * time.Second) - defer ticker.Stop() - - timeout := time.NewTimer(15 * time.Second) - defer timeout.Stop() - - for { - select { - case <-s.stopHolepunch: - logger.Info("Stopping UDP holepunch") - return - case <-timeout.C: - logger.Info("UDP holepunch routine timed out after 15 seconds") - return - case <-ticker.C: - if err := s.sendUDPHolePunch(host + ":21820"); err != nil { - logger.Debug("Failed to send UDP hole punch: %v", err) - } - } - } -} - func (s *WireGuardService) updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { var replace = false for _, t := range targetData.Targets { @@ -1242,8 +1094,8 @@ func (s *WireGuardService) ReplaceNetstack() error { s.tun = newTun s.tnet = newTnet - // Create new WireGuard device with same port - s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger( + // Create new WireGuard device with same shared bind + s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger( device.LogLevelSilent, "wireguard: ", ))